diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8556c601c..7255a536e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -142,6 +142,10 @@ jobs: - name: Test run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' test + - name: Check binary compatibility + if: matrix.java == 'corretto@11' && matrix.os == 'ubuntu-22.04' + run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' mimaReportBinaryIssues + - name: Generate API documentation if: matrix.java == 'corretto@11' && matrix.os == 'ubuntu-22.04' run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' doc @@ -151,11 +155,11 @@ jobs: - name: Make target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: mkdir -p module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-hikari/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-dsl/.jvm/target project/target + run: mkdir -p module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-authentication-plugin/native/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target 
module/ldbc-codegen/js/target module/ldbc-zio-interop/.jvm/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-authentication-plugin/js/target module/ldbc-aws-authentication-plugin/jvm/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-zio-interop/.js/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-authentication-plugin/jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-aws-authentication-plugin/native/target module/ldbc-dsl/.jvm/target module/ldbc-aws-authentication-plugin/js/target project/target - name: Compress target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: tar cf targets.tar module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target module/jdbc-connector/.jvm/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-hikari/target module/ldbc-sql/.jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-dsl/.jvm/target project/target + run: tar cf targets.tar module/ldbc-query-builder/.js/target module/ldbc-codegen/native/target 
module/jdbc-connector/.jvm/target module/ldbc-authentication-plugin/native/target module/ldbc-query-builder/.native/target module/ldbc-codegen/jvm/target module/ldbc-query-builder/.jvm/target module/ldbc-dsl/.native/target module/ldbc-connector/js/target module/ldbc-codegen/js/target module/ldbc-zio-interop/.jvm/target module/ldbc-core/.native/target module/ldbc-sql/.js/target module/ldbc-authentication-plugin/js/target module/ldbc-aws-authentication-plugin/jvm/target module/ldbc-statement/.native/target module/ldbc-core/.js/target module/ldbc-schema/.js/target module/ldbc-sql/.native/target module/ldbc-zio-interop/.js/target module/ldbc-schema/.native/target module/ldbc-statement/.jvm/target module/ldbc-core/.jvm/target module/ldbc-dsl/.js/target module/ldbc-sql/.jvm/target module/ldbc-authentication-plugin/jvm/target module/ldbc-statement/.js/target module/ldbc-connector/native/target module/ldbc-connector/jvm/target module/ldbc-schema/.jvm/target plugin/target module/ldbc-aws-authentication-plugin/native/target module/ldbc-dsl/.jvm/target module/ldbc-aws-authentication-plugin/js/target project/target - name: Upload target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) @@ -390,7 +394,7 @@ jobs: - name: Submit Dependencies uses: scalacenter/sbt-dependency-submission@v2 with: - modules-ignore: ldbcjs_3 ldbcjs_3 otel_3 mcp-ldbc-document-server_sjs1_3 docs_3 docs_3 ldbcnative_3 ldbcnative_3 ldbcjvm_3 ldbcjvm_3 hikaricp_3 tests_sjs1_3 tests_sjs1_3 http4s_3 tests_3 tests_3 benchmark_3 benchmark_3 tests_native0.4_3 tests_native0.4_3 + modules-ignore: ldbcjs_3 ldbcjs_3 otel_3 mcp-ldbc-document-server_sjs1_3 docs_3 docs_3 zio_3 ldbcnative_3 ldbcnative_3 ldbcjvm_3 ldbcjvm_3 hikaricp_3 tests_sjs1_3 tests_sjs1_3 http4s_3 aws-iam-auth_3 tests_3 tests_3 benchmark_3 benchmark_3 tests_native0.4_3 tests_native0.4_3 configs-ignore: test scala-tool scala-doc-tool test-internal validate-steward: diff --git a/.scalafmt.conf 
b/.scalafmt.conf index cbe941411..3bf8cf582 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -104,7 +104,6 @@ rewrite { ["org\\.openjdk\\..*"], ["org\\.apache\\..*"], ["org\\.slf4j\\..*"], - ["org\\.schemaspy\\..*"], ["com\\.mysql\\..*"], ["com\\.zaxxer\\..*"], ["com\\.comcast\\..*"], @@ -117,7 +116,6 @@ rewrite { ["laika\\..*"], ["org\\.typelevel\\..*"], ["org\\.scalatest\\..*"], - ["org\\.specs2\\..*"], ["munit\\..*"], ["slick\\..*"], ["doobie\\..*"] @@ -130,8 +128,6 @@ rewrite { ["ldbc\\.codegen\\..*"], ["ldbc\\.connector\\..*"], ["jdbc\\.connector\\..*"], - ["ldbc\\.hikari\\..*"], - ["ldbc\\.schemaspy\\..*"], ["ldbc\\..*"], [".*"], ] diff --git a/README.md b/README.md index b1e77c3f2..8223539b1 100644 --- a/README.md +++ b/README.md @@ -30,17 +30,17 @@ ldbc is available on the JVM, Scala.js, and ScalaNative | Module / Platform | JVM | Scala Native | Scala.js | Scaladoc | |----------------------|:---:|:------------:|:--------:|-----------------------------------------------------------------------------------------------------------------------------------------------------------| -| `ldbc-sql` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-sql_3) | -| `ldbc-core` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-core_3) | -| `ldbc-connector` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-connector_3) | -| `jdbc-connector` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/jdbc-connector_3) | -| `ldbc-dsl` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3) | 
-| `ldbc-statement` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-statement_3) | -| `ldbc-query-builder` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-query-builder_3) | -| `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | -| `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | -| `ldbc-hikari` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-hikari_3) | -| `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | +| `ldbc-sql` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-sql_3) | +| `ldbc-core` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-core_3) | +| `ldbc-connector` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-connector_3) | +| `jdbc-connector` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/jdbc-connector_3) | +| `ldbc-dsl` | ✅ | ✅ | ✅ | 
[![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3) | +| `ldbc-statement` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-statement_3) | +| `ldbc-query-builder` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-query-builder_3) | +| `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | +| `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | +| `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | +| `ldbc-zio-interop` | ✅ | ❌ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-zio-interop_3) | ## Performance @@ -229,28 +229,25 @@ val result: IO[List[User]] = ## How to use with ZIO -Although ldbc was created to run on the Cats Effect, can also be used in conjunction with ZIO by using [ZIO Interop Cats](https://github.com/zio/interop-cats). +Although ldbc was created to run on the Cats Effect, can also be used in conjunction with ZIO by using `ldbc-zio-interop`. > [!CAUTION] > Although ldbc supports three platforms, Note that ZIO Interop Cats does not currently support Scala Native. ```scala -libraryDependencies += "dev.zio" %% "zio-interop-cats" % "" +libraryDependencies += "io.github.takapi327" %% "ldbc-zio-interop" % "latest" ``` The following is sample code for using ldbc with ZIO. 
```scala 3 import zio.* -import zio.interop.catz.* -object Main extends ZIOAppDefault: +import ldbc.zio.interop.* +import ldbc.connector.* +import ldbc.dsl.* - given cats.effect.std.Console[Task] = cats.effect.std.Console.make[Task] - given cats.effect.std.UUIDGen[Task] with - override def randomUUID: Task[UUID] = ZIO.attempt(UUID.randomUUID()) - given fs2.hashing.Hashing[Task] = fs2.hashing.Hashing.forSync[Task] - given fs2.io.net.Network[Task] = fs2.io.net.Network.forAsync[Task] +object Main extends ZIOAppDefault: private val datasource = MySQLDataSource diff --git a/build.sbt b/build.sbt index d32c547e1..cda5f4b30 100644 --- a/build.sbt +++ b/build.sbt @@ -6,7 +6,6 @@ import com.typesafe.tools.mima.core.* import BuildSettings.* -import Dependencies.* import Implicits.* import JavaVersions.* import ProjectKeys.* @@ -31,7 +30,6 @@ ThisBuild / githubWorkflowBuildPostamble += dockerStop ThisBuild / githubWorkflowTargetBranches := Seq("**") ThisBuild / githubWorkflowPublishTargetBranches := Seq(RefPredicate.StartsWith(Ref.Tag("v"))) ThisBuild / tlSitePublishBranch := None -ThisBuild / tlCiMimaBinaryIssueCheck := false lazy val sql = crossProject(JVMPlatform, JSPlatform, NativePlatform) .crossType(CrossType.Pure) @@ -122,6 +120,21 @@ lazy val jdbcConnector = crossProject(JVMPlatform) .defaultSettings .dependsOn(core) +lazy val authenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativePlatform) + .crossType(CrossType.Full) + .module("authentication-plugin", "MySQL authentication plugin written in pure Scala3") + .settings( + libraryDependencies ++= Seq( + "org.typelevel" %%% "cats-core" % "2.10.0", + "org.scodec" %%% "scodec-bits" % "1.1.38" + ) + ) + .jsSettings( + Test / scalaJSLinkerConfig ~= (_.withModuleKind(ModuleKind.CommonJSModule)) + ) + .nativeEnablePlugins(ScalaNativeBrewedConfigPlugin) + .nativeSettings(Test / nativeBrewFormulas += "s2n") + lazy val connector = crossProject(JVMPlatform, JSPlatform, NativePlatform) .crossType(CrossType.Full) 
.module("connector", "MySQL connector written in pure Scala3") @@ -143,19 +156,25 @@ lazy val connector = crossProject(JVMPlatform, JSPlatform, NativePlatform) ) .nativeEnablePlugins(ScalaNativeBrewedConfigPlugin) .nativeSettings(Test / nativeBrewFormulas += "s2n") - .dependsOn(core) + .dependsOn(core, authenticationPlugin) -lazy val hikari = LepusSbtProject("ldbc-hikari", "module/ldbc-hikari") - .settings(description := "Project to build HikariCP") +lazy val awsAuthenticationPlugin = crossProject(JVMPlatform, JSPlatform, NativePlatform) + .crossType(CrossType.Full) + .module("aws-authentication-plugin", "Project for the plugin used with Aurora IAM authentication") .settings( - onLoadMessage := s"${ scala.Console.RED }WARNING: This project is deprecated and will be removed in future versions.${ scala.Console.RESET }", libraryDependencies ++= Seq( - catsEffect, - typesafeConfig, - hikariCP - ) ++ specs2 + "co.fs2" %%% "fs2-core" % "3.12.2", + "co.fs2" %%% "fs2-io" % "3.12.2", + "io.github.cquiroz" %%% "scala-java-time" % "2.5.0", + "org.typelevel" %%% "munit-cats-effect" % "2.1.0" % Test + ) + ) + .jsSettings( + Test / scalaJSLinkerConfig ~= (_.withModuleKind(ModuleKind.CommonJSModule)) ) - .dependsOn(dsl.jvm) + .nativeEnablePlugins(ScalaNativeBrewedConfigPlugin) + .nativeSettings(Test / nativeBrewFormulas += "s2n") + .dependsOn(authenticationPlugin) lazy val plugin = LepusSbtPluginProject("ldbc-plugin", "plugin") .settings(description := "Projects that provide sbt plug-ins") @@ -168,6 +187,22 @@ lazy val plugin = LepusSbtPluginProject("ldbc-plugin", "plugin") ) }.taskValue) +lazy val zioInterop = crossProject(JVMPlatform, JSPlatform) + .crossType(CrossType.Pure) + .module("zio-interop", "Projects that provide a way to connect to the database for ZIO") + .settings( + libraryDependencies ++= Seq( + "dev.zio" %%% "zio" % "2.1.21", + "dev.zio" %%% "zio-interop-cats" % "23.1.0.5", + "dev.zio" %%% "zio-test" % "2.1.21" % Test, + "dev.zio" %%% "zio-test-sbt" % "2.1.21" % 
Test + ) + ) + .jsSettings( + Test / scalaJSLinkerConfig ~= (_.withModuleKind(ModuleKind.CommonJSModule)) + ) + .dependsOn(connector % "test->compile") + lazy val tests = crossProject(JVMPlatform, JSPlatform, NativePlatform) .crossType(CrossType.Full) .in(file("tests")) @@ -186,8 +221,8 @@ lazy val tests = crossProject(JVMPlatform, JSPlatform, NativePlatform) ) .defaultSettings .jvmSettings( - Test / fork := true, - libraryDependencies += mysql % Test + Test / fork := true, + libraryDependencies += "com.mysql" % "mysql-connector-j" % "8.4.0" % Test ) .jvmConfigure(_ dependsOn jdbcConnector.jvm) .jsSettings( @@ -212,13 +247,14 @@ lazy val benchmark = (project in file("benchmark")) .settings(Compile / javacOptions ++= Seq("--release", java21)) .settings( libraryDependencies ++= Seq( - scala3Compiler, - mysql, - doobie, - slick + "org.scala-lang" %% "scala3-compiler" % scala3, + "com.mysql" % "mysql-connector-j" % "8.4.0", + "org.tpolecat" %% "doobie-core" % "1.0.0-RC10", + "com.typesafe.slick" %% "slick" % "3.6.1", + "com.zaxxer" % "HikariCP" % "7.0.2" ) ) - .dependsOn(jdbcConnector.jvm, connector.jvm, queryBuilder.jvm, hikari) + .dependsOn(jdbcConnector.jvm, connector.jvm, queryBuilder.jvm) .enablePlugins(JmhPlugin, AutomateHeaderPlugin, NoPublishPlugin) lazy val http4sExample = crossProject(JVMPlatform) @@ -242,8 +278,8 @@ lazy val hikariCPExample = crossProject(JVMPlatform) .example("hikariCP", "HikariCP example project") .settings( libraryDependencies ++= Seq( - hikariCP, - mysql + "com.zaxxer" % "HikariCP" % "7.0.2", + "com.mysql" % "mysql-connector-j" % "8.4.0" ) ) .dependsOn(jdbcConnector, dsl) @@ -268,6 +304,32 @@ lazy val otelExample = crossProject(JVMPlatform) ) .dependsOn(connector, dsl) +lazy val zioExample = crossProject(JVMPlatform) + .crossType(CrossType.Pure) + .withoutSuffixFor(JVMPlatform) + .example("zio", "ZIO example project") + .settings( + libraryDependencies ++= Seq( + "dev.zio" %% "zio-http" % "3.5.1", + "dev.zio" %% "zio-json" % "0.7.44" + ) 
+ ) + .dependsOn(connector, dsl, zioInterop) + +lazy val awsIamAuthExample = crossProject(JVMPlatform) + .crossType(CrossType.Pure) + .withoutSuffixFor(JVMPlatform) + .example("aws-iam-auth", "Aws Iam Authentication example project") + .settings( + libraryDependencies ++= Seq( + "org.http4s" %% "http4s-dsl" % "0.23.33", + "org.http4s" %% "http4s-ember-server" % "0.23.33", + "org.http4s" %% "http4s-circe" % "0.23.33", + "io.circe" %% "circe-generic" % "0.14.10" + ) + ) + .dependsOn(connector, awsAuthenticationPlugin, dsl) + lazy val docs = (project in file("docs")) .settings( description := "Documentation for ldbc", @@ -276,7 +338,7 @@ lazy val docs = (project in file("docs")) mdocVariables ++= Map( "ORGANIZATION" -> organization.value, "SCALA_VERSION" -> scalaVersion.value, - "MYSQL_VERSION" -> mysqlVersion + "MYSQL_VERSION" -> "8.4.0" ), laikaTheme := LaikaSettings.helium.value, // Modify tlSite task to run the LLM docs script after the site is generated @@ -368,7 +430,9 @@ lazy val mcpDocumentServer = crossProject(JSPlatform) lazy val examples = Seq( http4sExample, hikariCPExample, - otelExample + otelExample, + zioExample, + awsIamAuthExample ) lazy val ldbc = tlCrossRootProject @@ -384,11 +448,13 @@ lazy val ldbc = tlCrossRootProject queryBuilder, schema, codegen, + zioInterop, + authenticationPlugin, + awsAuthenticationPlugin, plugin, tests, docs, benchmark, - hikari, mcpDocumentServer ) .aggregate(examples *) diff --git a/codecov.yml b/codecov.yml index 519985f37..cd19d211e 100644 --- a/codecov.yml +++ b/codecov.yml @@ -16,12 +16,11 @@ coverage: - "module/ldbc-sql" - "module/ldbc-core" - "module/jdbc-connector" - - "module/ldbc-schemaspy" - "tests/shared/src/main/scala/ldbc/tests/model" - "module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/CallableStatementImpl.scala" - "module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/ServerPreparedStatement.scala" - 
"module/ldbc-connector/jvm/src/main/scala/ldbc/connector/SSLPlatform.scala" - "module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Initial.scala" - "module/ldbc-schema/src/main/scala/ldbc/schema/Character.scala" - - "module/ldbc-hikari/src/main/scala/ldbc/hikari/Configuration.scala" - "module/ldbc-dsl/src/main/scala/ldbc/dsl/ResultSetConsumer.scala" + - "module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala" diff --git a/docker-compose.yml b/docker-compose.yml index b94143052..421d2cbff 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,6 +18,22 @@ services: timeout: 20s retries: 10 + localstack: + image: localstack/localstack:latest + platform: linux/amd64 + container_name: ldbc-localstack + ports: + - "4566:4566" + - "4571:4571" + environment: + - SERVICES=iam,sts + - DEBUG=1 + - DOCKER_HOST=unix:///var/run/docker.sock + - AWS_DEFAULT_REGION=ap-northeast-1 + volumes: + - ./localstack:/etc/localstack/init/ready.d + - /var/run/docker.sock:/var/run/docker.sock + verdaccio: image: verdaccio/verdaccio:nightly-master container_name: verdaccio diff --git a/docs/src/main/mdoc/README.md b/docs/src/main/mdoc/README.md index 406fe091b..f7dab8477 100644 --- a/docs/src/main/mdoc/README.md +++ b/docs/src/main/mdoc/README.md @@ -39,17 +39,17 @@ ldbc is available on the JVM, Scala.js, and ScalaNative | Module / Platform | JVM | Scala Native | Scala.js | Scaladoc | |----------------------|:---:|:------------:|:--------:|-----------------------------------------------------------------------------------------------------------------------------------------------------------| -| `ldbc-sql` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-sql_3) | -| `ldbc-core` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-core_3) | -| 
`ldbc-connector` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-connector_3) | -| `jdbc-connector` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/jdbc-connector_3) | -| `ldbc-dsl` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3) | -| `ldbc-statement` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-statement_3) | -| `ldbc-query-builder` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-query-builder_3) | -| `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | -| `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | -| `ldbc-hikari` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-hikari_3) | -| `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | +| `ldbc-sql` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-sql_3) | +| `ldbc-core` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-core_3) 
| +| `ldbc-connector` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-connector_3) | +| `jdbc-connector` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/jdbc-connector_3) | +| `ldbc-dsl` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3) | +| `ldbc-statement` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-statement_3) | +| `ldbc-query-builder` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-query-builder_3) | +| `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | +| `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | +| `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | +| `ldbc-zio-interop` | ✅ | ❌ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-zio-interop_3) | ## Quick Start diff --git a/docs/src/main/mdoc/en/index.md b/docs/src/main/mdoc/en/index.md index e03b833de..e8c7feb85 100644 --- a/docs/src/main/mdoc/en/index.md +++ b/docs/src/main/mdoc/en/index.md @@ -66,6 +66,5 @@ libraryDependencies ++= Seq( - Geometry data type support - CHECK constraint support - Support for databases other 
than MySQL -- ZIO module support - Test kit - etc... diff --git a/docs/src/main/mdoc/en/migration-notes.md b/docs/src/main/mdoc/en/migration-notes.md index 380e8dc25..1a0c96dae 100644 --- a/docs/src/main/mdoc/en/migration-notes.md +++ b/docs/src/main/mdoc/en/migration-notes.md @@ -32,7 +32,6 @@ | `ldbc-query-builder` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-query-builder_3) | | `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | | `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | -| `ldbc-hikari` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-hikari_3) | | `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | ## 🎯 Major Changes diff --git a/docs/src/main/mdoc/en/qa/How-to-use-with-ZIO.md b/docs/src/main/mdoc/en/qa/How-to-use-with-ZIO.md index 7bee3a856..646369374 100644 --- a/docs/src/main/mdoc/en/qa/How-to-use-with-ZIO.md +++ b/docs/src/main/mdoc/en/qa/How-to-use-with-ZIO.md @@ -5,36 +5,26 @@ laika.metadata.language = ja # Q: How to use with ZIO? -## A: For use with ZIO, use `zio-interop-cats`. +## A: For use with ZIO, use `ldbc-zio-interop`. ```scala -libraryDependencies += "dev.zio" %% "zio-interop-cats" % "" +libraryDependencies += "io.github.takapi327" %% "ldbc-zio-interop" % "latest" ``` The following is sample code for using ldbc with ZIO. 
```scala 3 mdoc -import java.util.UUID - -import cats.effect.std.UUIDGen - -import fs2.hashing.Hashing -import fs2.io.net.Network - import zio.* -import zio.interop.catz.* -object Main extends ZIOAppDefault: +import ldbc.zio.interop.* +import ldbc.connector.* +import ldbc.dsl.* - given cats.effect.std.Console[Task] = cats.effect.std.Console.make[Task] - given UUIDGen[Task] with - override def randomUUID: Task[UUID] = ZIO.attempt(UUID.randomUUID()) - given Hashing[Task] = Hashing.forSync[Task] - given Network[Task] = Network.forAsync[Task] +object Main extends ZIOAppDefault: - private def datasource = + private val datasource = MySQLDataSource - .build[Task]("127.0.0.1", 13306, "ldbc") + .build[Task]("127.0.0.1", 3306, "ldbc") .setPassword("password") .setDatabase("world") .setSSL(SSL.Trusted) @@ -51,7 +41,7 @@ object Main extends ZIOAppDefault: } ``` -### パフォーマンス +### Performance Performance results from the Cats Effect to ZIO conversion are shown below. diff --git a/docs/src/main/mdoc/en/reference/Connector.md b/docs/src/main/mdoc/en/reference/Connector.md index e5f721618..9d4d958a9 100644 --- a/docs/src/main/mdoc/en/reference/Connector.md +++ b/docs/src/main/mdoc/en/reference/Connector.md @@ -115,9 +115,17 @@ ldbc currently supports the following authentication plugins: - Native Pluggable Authentication - SHA-256 Pluggable Authentication - SHA-2 Pluggable Authentication Cache +- Cleartext Pluggable Authentication *Note: Native Pluggable Authentication and SHA-256 Pluggable Authentication are deprecated plugins from MySQL 8.x. It is recommended to use the SHA-2 Pluggable Authentication Cache unless there is a specific reason.* +@:callout(warning) +**Important Security Precautions:** +MySQL cleartext pluggable authentication is an authentication method that sends passwords in plain text to the server. When using this authentication plugin, you must enable SSL/TLS connections. 
Using it without SSL/TLS connections is extremely dangerous from a security standpoint, as passwords are transmitted over the network in plain text. + +This authentication plugin is primarily used for integration with AWS IAM authentication and other external authentication systems. For instructions on configuring SSL/TLS connections, refer to the [SSL configuration section of the tutorial](/en/tutorial/Connection.md#connection-with-ssl-configuration). +@:@ + You do not need to be aware of authentication plugins in the ldbc application code. Users should create users with the desired authentication plugin on the MySQL database and use those users to attempt connections to MySQL in the ldbc application code. ldbc internally determines the authentication plugin and connects to MySQL using the appropriate authentication plugin. diff --git a/docs/src/main/mdoc/en/tutorial/Connection.md b/docs/src/main/mdoc/en/tutorial/Connection.md index 5796aec32..0d3321429 100644 --- a/docs/src/main/mdoc/en/tutorial/Connection.md +++ b/docs/src/main/mdoc/en/tutorial/Connection.md @@ -148,6 +148,8 @@ You can add SSL configuration to establish a secure connection: ※ Note that Trusted accepts all certificates. This is a setting for development environments. +※ For security reasons, SSL/TLS connections are required for certain authentication plugins, such as MySQL cleartext pluggable authentication. For details, see the [authentication section of the reference](/en/reference/Connector.md#authentication). 
+ ```scala import cats.effect.IO import ldbc.connector.* diff --git a/docs/src/main/mdoc/index.md b/docs/src/main/mdoc/index.md index 637cb365f..e0286b976 100644 --- a/docs/src/main/mdoc/index.md +++ b/docs/src/main/mdoc/index.md @@ -38,17 +38,17 @@ ldbc is available on the JVM, Scala.js, and ScalaNative | Module / Platform | JVM | Scala Native | Scala.js | Scaladoc | |----------------------|:---:|:------------:|:--------:|-----------------------------------------------------------------------------------------------------------------------------------------------------------| -| `ldbc-sql` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-sql_3) | -| `ldbc-core` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-core_3) | -| `ldbc-connector` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-connector_3) | -| `jdbc-connector` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/jdbc-connector_3) | -| `ldbc-dsl` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3) | -| `ldbc-statement` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-statement_3) | -| `ldbc-query-builder` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-query-builder_3) | -| `ldbc-schema` | ✅ | ✅ | ✅ | 
[![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | -| `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | -| `ldbc-hikari` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-hikari_3) | -| `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.1-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | +| `ldbc-sql` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-sql_3) | +| `ldbc-core` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-core_3) | +| `ldbc-connector` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-connector_3) | +| `jdbc-connector` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/jdbc-connector_3) | +| `ldbc-dsl` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3) | +| `ldbc-statement` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-statement_3) | +| `ldbc-query-builder` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-query-builder_3) | +| `ldbc-schema` | ✅ | ✅ 
| ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | +| `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | +| `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | +| `ldbc-zio-interop` | ✅ | ❌ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.5.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-zio-interop_3) | ## Quick Start diff --git a/docs/src/main/mdoc/ja/index.md b/docs/src/main/mdoc/ja/index.md index eb14b520e..dc1496fd0 100644 --- a/docs/src/main/mdoc/ja/index.md +++ b/docs/src/main/mdoc/ja/index.md @@ -66,6 +66,5 @@ libraryDependencies ++= Seq( - Geometryデータタイプのサポート - CHECK制約のサポート - MySQL以外のデータベースサポート -- ZIOモジュールのサポート - テストキット - etc... 
diff --git a/docs/src/main/mdoc/ja/migration-notes.md b/docs/src/main/mdoc/ja/migration-notes.md index 8624b298d..7c14e0477 100644 --- a/docs/src/main/mdoc/ja/migration-notes.md +++ b/docs/src/main/mdoc/ja/migration-notes.md @@ -32,7 +32,6 @@ | `ldbc-query-builder` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-query-builder_3) | | `ldbc-schema` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-schema_3) | | `ldbc-codegen` | ✅ | ✅ | ✅ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-codegen_3) | -| `ldbc-hikari` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-hikari_3) | | `ldbc-plugin` | ✅ | ❌ | ❌ | [![Scaladoc](https://img.shields.io/badge/javadoc-0.4.0-brightgreen.svg?label=Scaladoc)](https://javadoc.io/doc/io.github.takapi327/ldbc-plugin_2.12_1.0) | ## 🎯 主要な変更点 diff --git a/docs/src/main/mdoc/ja/qa/How-to-use-with-ZIO.md b/docs/src/main/mdoc/ja/qa/How-to-use-with-ZIO.md index abafba1d1..d15366141 100644 --- a/docs/src/main/mdoc/ja/qa/How-to-use-with-ZIO.md +++ b/docs/src/main/mdoc/ja/qa/How-to-use-with-ZIO.md @@ -5,39 +5,30 @@ laika.metadata.language = ja # Q: ZIOで使用する方法は? 
-## A: ZIOで使用する場合、`zio-interop-cats`を使用します。 +## A: ZIOで使用する場合、`ldbc-zio-interop`を使用します。 ```scala -libraryDependencies += "dev.zio" %% "zio-interop-cats" % "" +libraryDependencies += "io.github.takapi327" %% "ldbc-zio-interop" % "latest" ``` 以下は、ZIOを使用してldbcを利用するためのサンプルコードです。 ```scala 3 mdoc -import java.util.UUID - -import cats.effect.std.UUIDGen - -import fs2.hashing.Hashing -import fs2.io.net.Network - import zio.* -import zio.interop.catz.* + +import ldbc.zio.interop.* +import ldbc.connector.* +import ldbc.dsl.* object Main extends ZIOAppDefault: - given cats.effect.std.Console[Task] = cats.effect.std.Console.make[Task] - given UUIDGen[Task] with - override def randomUUID: Task[UUID] = ZIO.attempt(UUID.randomUUID()) - given Hashing[Task] = Hashing.forSync[Task] - given Network[Task] = Network.forAsync[Task] - - private def datasource = MySQLDataSource - .build[Task]("127.0.0.1", 13306, "ldbc") - .setPassword("password") - .setDatabase("world") - .setSSL(SSL.Trusted) - + private val datasource = + MySQLDataSource + .build[Task]("127.0.0.1", 3306, "ldbc") + .setPassword("password") + .setDatabase("world") + .setSSL(SSL.Trusted) + private val connector = Connector.fromDataSource(datasource) override def run = diff --git a/docs/src/main/mdoc/ja/reference/Connector.md b/docs/src/main/mdoc/ja/reference/Connector.md index 510463307..c15a01464 100644 --- a/docs/src/main/mdoc/ja/reference/Connector.md +++ b/docs/src/main/mdoc/ja/reference/Connector.md @@ -115,9 +115,17 @@ ldbcは現時点で以下の認証プラグインをサポートしています - ネイティブプラガブル認証 - SHA-256 プラガブル認証 - SHA-2 プラガブル認証のキャッシュ +- クリアテキストプラガブル認証 ※ ネイティブプラガブル認証とSHA-256 プラガブル認証はMySQL 8.xから非推奨となったプラグインです。特段理由がない場合はSHA-2 プラガブル認証のキャッシュを使用することを推奨します。 +@:callout(warning) +**重要なセキュリティ注意事項:** +MySQL クリアテキストプラガブル認証は、パスワードを平文でサーバーに送信する認証方式です。この認証プラグインを使用する場合は、必ずSSL/TLS接続を有効にしてください。SSL/TLS接続なしでの使用は、パスワードがネットワーク上で平文として送信されるため、セキュリティ上非常に危険です。 + +この認証プラグインは主にAWS 
IAM認証やその他の外部認証システムとの統合で使用されます。SSL/TLS接続の設定方法については、[チュートリアルのSSL設定セクション](/ja/tutorial/Connection.md#ssl設定を使用したコネクション)を参照してください。 +@:@ + ldbcのアプリケーションコード上で認証プラグインを意識する必要はありません。ユーザーはMySQLのデータベース上で使用したい認証プラグインで作成されたユーザーを作成し、ldbcのアプリケーションコード上ではそのユーザーを使用してMySQLへの接続を試みるだけで問題ありません。 ldbcが内部で認証プラグインを判断し、適切な認証プラグインを使用してMySQLへの接続を行います。 diff --git a/docs/src/main/mdoc/ja/tutorial/Connection.md b/docs/src/main/mdoc/ja/tutorial/Connection.md index ab15892ae..9c3903b9a 100644 --- a/docs/src/main/mdoc/ja/tutorial/Connection.md +++ b/docs/src/main/mdoc/ja/tutorial/Connection.md @@ -146,6 +146,8 @@ val program = datasource.getConnection.use { connection => ※ Trustedは全ての証明書を受け入れることに注意してください。これは開発環境向けの設定です。 +※ MySQL クリアテキストプラガブル認証などの一部の認証プラグインでは、セキュリティ上の理由からSSL/TLS接続が必須です。詳細は[リファレンスの認証セクション](/ja/reference/Connector.md#認証)を参照してください。 + ```scala import cats.effect.IO import ldbc.connector.* diff --git a/docs/src/main/mdoc/older-versions.md b/docs/src/main/mdoc/older-versions.md index 49002545f..7ba431cbe 100644 --- a/docs/src/main/mdoc/older-versions.md +++ b/docs/src/main/mdoc/older-versions.md @@ -17,13 +17,16 @@ and now also more API stability than those older releases. 
| Release | Date | Scala Versions | sbt plugin | Scala.js | Scala Native | Doc Sources | API (Scaladoc) | |---------|----------|----------------|------------|----------|--------------|-----------------|-----------------| +| 0.4 | Oct 2025 | 3.3.x | 1.x | | | [Browse][doc04] | [Browse][api04] | | 0.3 | May 2025 | 3.3.x | 1.x | | | [Browse][doc03] | [Browse][api03] | | 0.2 | Mar 2024 | 3.3.x | 1.x | | | [Browse][doc02] | [Browse][api02] | | 0.1 | Dec 2023 | 3.3.x | 1.x | | | | [Browse][api01] | [doc02]: https://takapi327.github.io/ldbc/0.2/ [doc03]: https://takapi327.github.io/ldbc/0.3/ +[doc04]: https://takapi327.github.io/ldbc/0.4/ +[api04]: https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3/0.4.0/index.html [api03]: https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3/0.3.3/index.html [api02]: https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3/0.2.1/index.html [api01]: https://javadoc.io/doc/io.github.takapi327/ldbc-dsl_3/0.1.1/index.html diff --git a/docs/src/main/scala/01-Program.scala b/docs/src/main/scala/01-Program.scala index cf607f071..f26c71b02 100644 --- a/docs/src/main/scala/01-Program.scala +++ b/docs/src/main/scala/01-Program.scala @@ -18,17 +18,18 @@ import ldbc.connector.* // #program // #connection - def connection = ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc") + val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") .setPassword("password") .setSSL(SSL.Trusted) + + def connector = Connector.fromDataSource(dataSource) // #connection // #run - connection - .use { conn => - program.readOnly(conn).map(println(_)) - } + program + .readOnly(connector) + .map(println(_)) .unsafeRunSync() // 1 // #run diff --git a/docs/src/main/scala/02-Program.scala b/docs/src/main/scala/02-Program.scala index 20f309116..75528f552 100644 --- a/docs/src/main/scala/02-Program.scala +++ b/docs/src/main/scala/02-Program.scala @@ -18,17 +18,18 @@ import ldbc.connector.* // #program // #connection - def connection = ConnectionProvider 
- .default[IO]("127.0.0.1", 13306, "ldbc") + val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") .setPassword("password") .setSSL(SSL.Trusted) + + def connector = Connector.fromDataSource(dataSource) // #connection // #run - connection - .use { conn => - program.readOnly(conn).map(println(_)) - } + program + .readOnly(connector) + .map(println(_)) .unsafeRunSync() // Some(2) // #run diff --git a/docs/src/main/scala/03-Program.scala b/docs/src/main/scala/03-Program.scala index 45ac92f9c..6584829c8 100644 --- a/docs/src/main/scala/03-Program.scala +++ b/docs/src/main/scala/03-Program.scala @@ -23,17 +23,18 @@ import ldbc.connector.* // #program // #connection - def connection = ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc") + val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") .setPassword("password") .setSSL(SSL.Trusted) + + def connector = Connector.fromDataSource(dataSource) // #connection // #run - connection - .use { conn => - program.readOnly(conn).map(println(_)) - } + program + .readOnly(connector) + .map(println(_)) .unsafeRunSync() // (List(1), Some(2), 3) // #run diff --git a/docs/src/main/scala/04-Program.scala b/docs/src/main/scala/04-Program.scala index 7e9261fee..5c26ccbe4 100644 --- a/docs/src/main/scala/04-Program.scala +++ b/docs/src/main/scala/04-Program.scala @@ -19,17 +19,18 @@ import ldbc.connector.* // #program // #connection - def connection = ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc") + val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") .setPassword("password") .setSSL(SSL.Trusted) + + def connector = Connector.fromDataSource(dataSource) // #connection // #run - connection - .use { conn => - program.commit(conn).map(println(_)) - } + program + .commit(connector) + .map(println(_)) .unsafeRunSync() // 1 // #run diff --git a/docs/src/main/scala/05-Program.scala b/docs/src/main/scala/05-Program.scala index 0e3494af8..9234daad5 100644 --- 
a/docs/src/main/scala/05-Program.scala +++ b/docs/src/main/scala/05-Program.scala @@ -35,13 +35,12 @@ import ldbc.connector.* val program2: DBIO[(String, String, Status)] = sql"SELECT name, email, status FROM user WHERE id = 1".query[(String, String, Status)].unsafe - def connection = ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc") + val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") .setPassword("password") .setSSL(SSL.Trusted) - connection - .use { conn => - program1.commit(conn) *> program2.readOnly(conn).map(println(_)) - } + def connector = Connector.fromDataSource(dataSource) + + (program1.commit(connector) *> program2.readOnly(connector).map(println(_))) .unsafeRunSync() diff --git a/docs/src/main/scala/0X-Cleanup.scala b/docs/src/main/scala/0X-Cleanup.scala index a6f3be7aa..09c879558 100644 --- a/docs/src/main/scala/0X-Cleanup.scala +++ b/docs/src/main/scala/0X-Cleanup.scala @@ -19,15 +19,15 @@ import ldbc.connector.* // #cleanupDatabase // #connection - def connection = ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc") + val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") .setPassword("password") + def connector = Connector.fromDataSource(dataSource) // #connection // #run - connection - .use { conn => - dropDatabase.commit(conn).as(println("Database dropped")) - } + dropDatabase + .commit(connector) + .as(println("Database dropped")) .unsafeRunSync() // #run diff --git a/examples/aws-iam-auth/src/main/scala/Main.scala b/examples/aws-iam-auth/src/main/scala/Main.scala new file mode 100644 index 000000000..94fc9056d --- /dev/null +++ b/examples/aws-iam-auth/src/main/scala/Main.scala @@ -0,0 +1,97 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +import com.comcast.ip4s.* + +import cats.effect.* +import cats.effect.std.{ Console, Env } + +import io.circe.* +import io.circe.syntax.* + +import ldbc.dsl.* + +import ldbc.connector.* + +import ldbc.amazon.plugin.AwsIamAuthenticationPlugin +import ldbc.logging.* + +import org.http4s.* +import org.http4s.circe.CirceEntityEncoder.* +import org.http4s.dsl.io.* +import org.http4s.ember.server.EmberServerBuilder + +case class City( + id: Int, + name: String, + countryCode: String, + district: String, + population: Int +) + +object City: + + given Encoder[City] = Encoder.derived[City] + +object Main extends ResourceApp.Forever: + + private val logHandler: LogHandler[IO] = { + case LogEvent.Success(sql, args) => + Console[IO].println( + s"""Successful Statement Execution: + | $sql + | + | arguments = [${ args.mkString(",") }] + |""".stripMargin + ) + case LogEvent.ProcessingFailure(sql, args, failure) => + Console[IO].errorln( + s"""Failed ResultSet Processing: + | $sql + | + | arguments = [${ args.mkString(",") }] + |""".stripMargin + ) >> Console[IO].printStackTrace(failure) + case LogEvent.ExecFailure(sql, args, failure) => + Console[IO].errorln( + s"""Failed Statement Execution: + | $sql + | + | arguments = [${ args.mkString(",") }] + |""".stripMargin + ) >> Console[IO].printStackTrace(failure) + } + + private def routes(connector: Connector[IO]): HttpRoutes[IO] = HttpRoutes.of[IO] { + case GET -> Root / "healthcheck" => Ok("Healthcheck") + case GET -> Root / "cities" => + for + cities <- sql"SELECT * FROM city".query[City].to[List].readOnly(connector) + result <- Ok(cities.asJson) + yield result + } + + override def run(args: List[String]): Resource[IO, Unit] = + for + hostname <- Resource.eval(Env[IO].get("AURORA_HOST").flatMap { + case Some(v) => IO.pure(v) + case None => IO.raiseError(new RuntimeException("AURORA_HOST is not set")) + }) + username <- 
Resource.eval(Env[IO].get("AURORA_USER").flatMap { + case Some(v) => IO.pure(v) + case None => IO.raiseError(new RuntimeException("AURORA_USER is not set")) + }) + config = MySQLConfig.default.setHost(hostname).setUser(username).setDatabase("world").setSSL(SSL.Trusted) + plugin = AwsIamAuthenticationPlugin.default[IO]("ap-northeast-1", hostname, username) + datasource <- MySQLDataSource.pooling[IO](config, plugins = List(plugin)) + connector = Connector.fromDataSource(datasource, Some(logHandler)) + _ <- EmberServerBuilder + .default[IO] + .withHost(host"0.0.0.0") + .withPort(port"9000") + .withHttpApp(routes(connector).orNotFound) + .build + yield () diff --git a/examples/http4s/src/main/scala/Main.scala b/examples/http4s/src/main/scala/Main.scala index 40712ea4c..2e9412298 100644 --- a/examples/http4s/src/main/scala/Main.scala +++ b/examples/http4s/src/main/scala/Main.scala @@ -48,24 +48,24 @@ object Main extends ResourceApp.Forever: private val cityTable = TableQuery[CityTable] - private def provider = - ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc", "password", "world") - .setSSL(SSL.Trusted) - - private def routes(conn: Connector[IO]): HttpRoutes[IO] = HttpRoutes.of[IO] { + private val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13306, "ldbc") + .setPassword("password") + .setDatabase("world") + .setSSL(SSL.Trusted) + private val connector = Connector.fromDataSource(dataSource) + + private def routes: HttpRoutes[IO] = HttpRoutes.of[IO] { case GET -> Root / "cities" => for - cities <- cityTable.selectAll.query.to[List].readOnly(conn) + cities <- cityTable.selectAll.query.to[List].readOnly(connector) result <- Ok(cities.asJson) yield result } override def run(args: List[String]): Resource[IO, Unit] = - for - conn <- provider.createConnector() - _ <- EmberServerBuilder - .default[IO] - .withHttpApp(routes(conn).orNotFound) - .build + for _ <- EmberServerBuilder + .default[IO] + .withHttpApp(routes.orNotFound) + .build yield () diff --git 
a/examples/otel/src/main/scala/Main.scala b/examples/otel/src/main/scala/Main.scala index 978f7922c..174724466 100644 --- a/examples/otel/src/main/scala/Main.scala +++ b/examples/otel/src/main/scala/Main.scala @@ -18,18 +18,18 @@ object Main extends IOApp.Simple: private val serviceName = "ldbc-otel-example" + private val dataSource = MySQLDataSource + .build[IO]("127.0.0.1", 13307, "ldbc") + .setPassword("password") + .setDatabase("world") + private def resource: Resource[IO, Connector[IO]] = for otel <- Resource .eval(IO.delay(GlobalOpenTelemetry.get)) .evalMap(OtelJava.fromJOpenTelemetry[IO]) - tracer <- Resource.eval(otel.tracerProvider.get(serviceName)) - connection <- ConnectionProvider - .default[IO]("127.0.0.1", 13307, "ldbc", "password", "world") - .setSSL(SSL.Trusted) - .setTracer(tracer) - .createConnector() - yield connection + tracer <- Resource.eval(otel.tracerProvider.get(serviceName)) + yield Connector.fromDataSource(dataSource.setTracer(tracer)) override def run: IO[Unit] = resource.use { conn => diff --git a/examples/zio/src/main/scala/Main.scala b/examples/zio/src/main/scala/Main.scala new file mode 100644 index 000000000..b00613760 --- /dev/null +++ b/examples/zio/src/main/scala/Main.scala @@ -0,0 +1,53 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +import ldbc.dsl.* + +import ldbc.connector.* + +import ldbc.zio.interop.* + +import zio.* +import zio.http.* +import zio.interop.catz.* +import zio.json.* + +object Main extends ZIOAppDefault: + + private val poolConfig = MySQLConfig.default + .setHost("localhost") + .setPort(13306) + .setUser("ldbc") + .setPassword("password") + .setDatabase("world") + .setSSL(SSL.Trusted) + .setMinConnections(5) + .setMaxConnections(10) + + private val connectorLayer: ZLayer[Any, Throwable, Connector[Task]] = + ZLayer.scoped { + MySQLDataSource.pooling[Task](poolConfig).map(ds => Connector.fromDataSource[Task](ds)).toScopedZIO + } + + private val routes = Routes( + Method.GET / Root -> handler(Response.text("Hello, World!")), + Method.GET / Root / "countries" -> handler { + for + connector <- ZIO.service[Connector[Task]] + countries <- sql"SELECT Name FROM country".query[String].to[List].readOnly(connector) + yield Response.json(countries.toJson) + }.catchAll { error => + handler(Response.json(Map("error" -> error.getMessage).toJson)) + } + ) + + override def run = + Server + .serve(routes) + .provide( + Server.default, + connectorLayer + ) diff --git a/localstack/init.sh b/localstack/init.sh new file mode 100644 index 000000000..7336f7e73 --- /dev/null +++ b/localstack/init.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# Wait for LocalStack to be ready +echo "Waiting for LocalStack to be ready..." 
+sleep 5 + +# Set AWS credentials for LocalStack +export AWS_ACCESS_KEY_ID=test +export AWS_SECRET_ACCESS_KEY=test + +aws --endpoint-url=http://localhost:4566 iam create-role \ + --role-name localstack-role \ + --assume-role-policy-document '{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":"arn:aws:iam::000000000000:root"},"Action":"sts:AssumeRole"}]}' diff --git a/module/jdbc-connector/src/main/scala/jdbc/connector/ConnectionProvider.scala b/module/jdbc-connector/src/main/scala/jdbc/connector/ConnectionProvider.scala deleted file mode 100644 index 1b16dff9c..000000000 --- a/module/jdbc-connector/src/main/scala/jdbc/connector/ConnectionProvider.scala +++ /dev/null @@ -1,174 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package jdbc.connector - -import java.sql.DriverManager - -import javax.sql.DataSource - -import scala.concurrent.ExecutionContext - -import cats.effect.* - -import ldbc.sql.Connection - -import ldbc.{ Connector, Provider } -import ldbc.logging.LogHandler - -@deprecated( - "Connection creation using ConnectionProvider is now deprecated. Please use jdbc.connector.MySQLDataSource from now on. 
ConnectionProvider will be removed in version 0.5.x.", - "ldbc 0.4.0" -) -object ConnectionProvider: - - private case class DataSourceProvider[F[_]]( - dataSource: DataSource, - connectEC: ExecutionContext, - logHandler: Option[LogHandler[F]] - )(using ev: Async[F]) - extends Provider[F]: - override def createConnection(): Resource[F, Connection[F]] = - Resource - .fromAutoCloseable(ev.evalOn(ev.delay(dataSource.getConnection()), connectEC)) - .map(conn => ConnectionImpl(conn)) - - override def use[A](f: Connector[F] => F[A]): F[A] = - createConnector().use(f) - - override def createConnector(): Resource[F, Connector[F]] = - createConnection().map(conn => Connector.fromConnection(conn, logHandler)) - - private case class JavaConnectionProvider[F[_]: Sync]( - connection: java.sql.Connection, - logHandler: Option[LogHandler[F]] - ) extends Provider[F]: - override def createConnection(): Resource[F, Connection[F]] = - Resource.pure(ConnectionImpl(connection)) - - override def use[A](f: Connector[F] => F[A]): F[A] = - createConnector().use(f) - - override def createConnector(): Resource[F, Connector[F]] = - createConnection().map(conn => Connector.fromConnection(conn, logHandler)) - - class DriverProvider[F[_]](using ev: Async[F]): - - private def create( - driver: String, - conn: () => java.sql.Connection, - logHandler: Option[LogHandler[F]] - ): Provider[F] = - new Provider[F]: - override def createConnection(): Resource[F, Connection[F]] = - Resource - .fromAutoCloseable(ev.blocking { - Class.forName(driver) - conn() - }) - .map(conn => ConnectionImpl(conn)) - - override def use[A](f: Connector[F] => F[A]): F[A] = - createConnector().use(f) - - override def createConnector(): Resource[F, Connector[F]] = - createConnection().map(conn => Connector.fromConnection(conn, logHandler)) - - /** Construct a new `Provider` that uses the JDBC `DriverManager` to allocate connections. 
- * - * @param driver - * the class name of the JDBC driver, like "com.mysql.cj.jdbc.MySQLDriver" - * @param url - * a connection URL, specific to your driver - * @param logHandler - * Handler for outputting logs of process execution using connections. - */ - def apply( - driver: String, - url: String, - logHandler: Option[LogHandler[F]] - ): Provider[F] = - create(driver, () => DriverManager.getConnection(url), logHandler) - - /** Construct a new `Provider` that uses the JDBC `DriverManager` to allocate connections. - * - * @param driver - * the class name of the JDBC driver, like "com.mysql.cj.jdbc.MySQLDriver" - * @param url - * a connection URL, specific to your driver - * @param user - * database username - * @param password - * database password - * @param logHandler - * Handler for outputting logs of process execution using connections. - */ - def apply( - driver: String, - url: String, - user: String, - password: String, - logHandler: Option[LogHandler[F]] - ): Provider[F] = - create(driver, () => DriverManager.getConnection(url, user, password), logHandler) - - /** Construct a new `Provider` that uses the JDBC `DriverManager` to allocate connections. - * - * @param driver - * the class name of the JDBC driver, like "com.mysql.cj.jdbc.MySQLDriver" - * @param url - * a connection URL, specific to your driver - * @param info - * a `Properties` containing connection information (see `DriverManager.getConnection`) - * @param logHandler - * Handler for outputting logs of process execution using connections. - */ - def apply( - driver: String, - url: String, - info: java.util.Properties, - logHandler: Option[LogHandler[F]] - ): Provider[F] = - create(driver, () => DriverManager.getConnection(url, info), logHandler) - - /** - * Construct a constructor of `Provider[F]` for some `D <: DataSource` When calling this constructor you - * should explicitly supply the effect type M e.g. 
ConnectionProvider.fromDataSource[IO](myDataSource, connectEC) - * - * @param dataSource - * A data source that manages connection information to MySQL. - * @param connectEC - * Execution context dedicated to database connection. - * @param logHandler - * Handler for outputting logs of process execution using connections. - */ - def fromDataSource[F[_]: Async]( - dataSource: DataSource, - connectEC: ExecutionContext, - logHandler: Option[LogHandler[F]] = None - ): Provider[F] = DataSourceProvider[F](dataSource, connectEC, logHandler) - - /** - * Construct a `Provider` that wraps an existing `Connection`. Closing the connection is the responsibility of - * the caller. - * - * @param connection - * a raw JDBC `Connection` to wrap - * @param logHandler - * Handler for outputting logs of process execution using connections. - */ - def fromConnection[F[_]: Sync]( - connection: java.sql.Connection, - logHandler: Option[LogHandler[F]] = None - ): Provider[F] = JavaConnectionProvider[F](connection, logHandler) - - /** Module of constructors for `Provider` that use the JDBC `DriverManager` to allocate connections. Note that - * `DriverManager` is unbounded and will happily allocate new connections until server resources are exhausted. It - * is usually preferable to use `DataSourceTransactor` with an underlying bounded connection pool. Blocking operations on `DriverProvider` are - * executed on an unbounded cached daemon thread pool by default, so you are also at risk of exhausting system - * threads. TL;DR this is fine for console apps but don't use it for a web application. 
- */ - def fromDriverManager[F[_]: Async]: DriverProvider[F] = new DriverProvider[F] diff --git a/module/jdbc-connector/src/main/scala/jdbc/connector/package.scala b/module/jdbc-connector/src/main/scala/jdbc/connector/package.scala index 95e26409e..2ea26a1e5 100644 --- a/module/jdbc-connector/src/main/scala/jdbc/connector/package.scala +++ b/module/jdbc-connector/src/main/scala/jdbc/connector/package.scala @@ -9,4 +9,3 @@ package jdbc package object connector: export ldbc.Connector export ldbc.DataSource - export ldbc.Provider diff --git a/module/ldbc-connector/js/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala b/module/ldbc-authentication-plugin/js/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala similarity index 75% rename from module/ldbc-connector/js/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala rename to module/ldbc-authentication-plugin/js/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala index af1d261e5..fdba42a91 100644 --- a/module/ldbc-connector/js/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala +++ b/module/ldbc-authentication-plugin/js/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala @@ -4,16 +4,25 @@ * For more information see LICENSE or https://opensource.org/licenses/MIT */ -package ldbc.connector.authenticator +package ldbc.authentication.plugin + import java.nio.charset.StandardCharsets import scala.scalajs.js import scala.scalajs.js.typedarray.Uint8Array import scodec.bits.ByteVector -trait Sha256PasswordPluginPlatform[F[_]] { self: Sha256PasswordPlugin[F] => + +trait EncryptPasswordPlugin: + private val crypto = js.Dynamic.global.require("crypto") + def transformation: String + + private def xorString(from: Array[Byte], scramble: Array[Byte], length: Int): Array[Byte] = + val scrambleLength = scramble.length + (0 until length).map(pos => (from(pos) ^ scramble(pos % 
scrambleLength)).toByte).toArray + def encryptPassword(password: String, scramble: Array[Byte], publicKeyString: String): Array[Byte] = val input = if password.nonEmpty then (password + "\u0000").getBytes(StandardCharsets.UTF_8) else Array[Byte](0) val mysqlScrambleBuff = xorString(input, scramble, input.length) @@ -28,4 +37,3 @@ trait Sha256PasswordPluginPlatform[F[_]] { self: Sha256PasswordPlugin[F] => ByteVector(input).toUint8Array ) ByteVector.view(encrypted.asInstanceOf[Uint8Array]).toArray -} diff --git a/module/ldbc-connector/jvm/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala b/module/ldbc-authentication-plugin/jvm/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala similarity index 82% rename from module/ldbc-connector/jvm/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala rename to module/ldbc-authentication-plugin/jvm/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala index cc9df619d..ec352e840 100644 --- a/module/ldbc-connector/jvm/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala +++ b/module/ldbc-authentication-plugin/jvm/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala @@ -4,7 +4,7 @@ * For more information see LICENSE or https://opensource.org/licenses/MIT */ -package ldbc.connector.authenticator +package ldbc.authentication.plugin import java.nio.charset.StandardCharsets import java.security.interfaces.RSAPublicKey @@ -15,7 +15,10 @@ import java.util.Base64 import javax.crypto.Cipher -trait Sha256PasswordPluginPlatform[F[_]] { self: Sha256PasswordPlugin[F] => +trait EncryptPasswordPlugin: + + def transformation: String + def encryptPassword(password: String, scramble: Array[Byte], publicKeyString: String): Array[Byte] = val input = if password.nonEmpty then (password + "\u0000").getBytes(StandardCharsets.UTF_8) else Array[Byte](0) val mysqlScrambleBuff = xorString(input, scramble, input.length) @@ -24,6 
+27,10 @@ trait Sha256PasswordPluginPlatform[F[_]] { self: Sha256PasswordPlugin[F] => decodeRSAPublicKey(publicKeyString) ) + private def xorString(from: Array[Byte], scramble: Array[Byte], length: Int): Array[Byte] = + val scrambleLength = scramble.length + (0 until length).map(pos => (from(pos) ^ scramble(pos % scrambleLength)).toByte).toArray + private def encryptWithRSAPublicKey(input: Array[Byte], key: PublicKey): Array[Byte] = val cipher = Cipher.getInstance(transformation) cipher.init(Cipher.ENCRYPT_MODE, key) @@ -36,4 +43,3 @@ trait Sha256PasswordPluginPlatform[F[_]] { self: Sha256PasswordPlugin[F] => val spec = new X509EncodedKeySpec(certificateData) val kf = KeyFactory.getInstance("RSA") kf.generatePublic(spec).asInstanceOf[RSAPublicKey] -} diff --git a/module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala b/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala similarity index 87% rename from module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala rename to module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala index 555f8310d..65b42b368 100644 --- a/module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPluginPlatform.scala +++ b/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/EncryptPasswordPlugin.scala @@ -4,15 +4,19 @@ * For more information see LICENSE or https://opensource.org/licenses/MIT */ -package ldbc.connector.authenticator +package ldbc.authentication.plugin + import java.nio.charset.StandardCharsets import scala.scalanative.unsafe.* import scala.scalanative.unsigned.* -import ldbc.connector.authenticator.Openssl.* +import ldbc.authentication.plugin.Openssl.* + +trait EncryptPasswordPlugin: + + def transformation: String -trait 
Sha256PasswordPluginPlatform[F[_]] { self: Sha256PasswordPlugin[F] => def encryptPassword(password: String, scramble: Array[Byte], publicKeyString: String): Array[Byte] = val input = if password.nonEmpty then (password + "\u0000").getBytes(StandardCharsets.UTF_8) else Array[Byte](0) val mysqlScrambleBuff = xorString(input, scramble, input.length) @@ -21,6 +25,10 @@ trait Sha256PasswordPluginPlatform[F[_]] { self: Sha256PasswordPlugin[F] => publicKeyString ) + private def xorString(from: Array[Byte], scramble: Array[Byte], length: Int): Array[Byte] = + val scrambleLength = scramble.length + (0 until length).map(pos => (from(pos) ^ scramble(pos % scrambleLength)).toByte).toArray + private def encryptWithRSAPublicKey(input: Array[Byte], publicKey: String): Array[Byte] = Zone { implicit zone => val publicKeyCStr = toCString(publicKey) @@ -64,4 +72,3 @@ trait Sha256PasswordPluginPlatform[F[_]] { self: Sha256PasswordPlugin[F] => result } -} diff --git a/module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Openssl.scala b/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/Openssl.scala similarity index 97% rename from module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Openssl.scala rename to module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/Openssl.scala index 4830853ad..5a9ecd34a 100644 --- a/module/ldbc-connector/native/src/main/scala/ldbc/connector/authenticator/Openssl.scala +++ b/module/ldbc-authentication-plugin/native/src/main/scala/ldbc/authentication/plugin/Openssl.scala @@ -4,7 +4,7 @@ * For more information see LICENSE or https://opensource.org/licenses/MIT */ -package ldbc.connector.authenticator +package ldbc.authentication.plugin import scala.scalanative.unsafe.* import scala.scalanative.unsigned.* diff --git a/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/AuthenticationPlugin.scala 
b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/AuthenticationPlugin.scala new file mode 100644 index 000000000..5e6233930 --- /dev/null +++ b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/AuthenticationPlugin.scala @@ -0,0 +1,64 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.authentication.plugin + +import scodec.bits.ByteVector + +/** + * A trait representing a MySQL authentication plugin for database connections. + * + * This trait defines the contract for various authentication mechanisms supported by MySQL, + * including traditional password-based authentication (mysql_native_password) and + * modern authentication methods like mysql_clear_password for IAM authentication. + * + * Authentication plugins are used during the MySQL handshake process to validate + * client credentials and establish secure database connections. + * + * @tparam F The effect type that wraps the authentication operations + */ +trait AuthenticationPlugin[F[_]]: + + /** + * The name of the authentication plugin as recognized by the MySQL server. + * + * Common plugin names include: + * - "mysql_native_password" for traditional SHA1-based password authentication + * - "mysql_clear_password" for plaintext password transmission over SSL + * - "caching_sha2_password" for SHA256-based password authentication + * - "mysql_old_password" for legacy MySQL authentication (deprecated) + * + * @return The plugin name string that identifies this authentication method + */ + def name: PluginName + + /** + * Indicates whether this authentication plugin requires a secure (encrypted) connection. 
+ * + * Some authentication plugins, particularly those that transmit passwords in cleartext + * (like mysql_clear_password), require SSL/TLS encryption to ensure data security. + * Traditional hashing-based plugins may optionally use encryption but don't strictly require it. + * + * @return true if SSL/TLS connection is mandatory for this plugin, false otherwise + */ + def requiresConfidentiality: Boolean + + /** + * Processes the password according to the authentication plugin's requirements. + * + * Different authentication plugins handle passwords differently: + * - mysql_native_password: Performs SHA1-based hashing with the server's scramble + * - mysql_clear_password: Returns the password as plaintext bytes (requires SSL) + * - caching_sha2_password: Performs SHA256-based hashing with salt + * + * @param password The user's password in plaintext + * @param scramble The random challenge bytes sent by the MySQL server during handshake. + * Used as salt/seed for cryptographic hashing in most authentication methods. + * May be ignored by plugins that don't use server-side challenges. + * @return The processed password data wrapped in the effect type F, ready for transmission + * to the MySQL server during authentication + */ + def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] diff --git a/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/MysqlClearPasswordPlugin.scala b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/MysqlClearPasswordPlugin.scala new file mode 100644 index 000000000..fb6d8e03a --- /dev/null +++ b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/MysqlClearPasswordPlugin.scala @@ -0,0 +1,29 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.authentication.plugin + +import java.nio.charset.StandardCharsets + +import scodec.bits.ByteVector + +import cats.Applicative + +trait MysqlClearPasswordPlugin[F[_]] extends AuthenticationPlugin[F]: + +  override final def name: PluginName = MYSQL_CLEAR_PASSWORD + override final def requiresConfidentiality: Boolean = true + +object MysqlClearPasswordPlugin: + +  private case class Impl[F[_]: Applicative]() extends MysqlClearPasswordPlugin[F]: +    override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = +      val result = +        if password.isEmpty then ByteVector.empty +        else ByteVector(password.getBytes(StandardCharsets.UTF_8)) +      Applicative[F].pure(result) + +  def apply[F[_]: Applicative](): MysqlClearPasswordPlugin[F] = Impl[F]() diff --git a/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/package.scala b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/package.scala new file mode 100644 index 000000000..7e7c04248 --- /dev/null +++ b/module/ldbc-authentication-plugin/shared/src/main/scala/ldbc/authentication/plugin/package.scala @@ -0,0 +1,45 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.authentication + +package object plugin: + + opaque type PluginName = String + + // Standard MySQL Authentication Plugins + val MYSQL_CLEAR_PASSWORD: PluginName = "mysql_clear_password" + val MYSQL_NATIVE_PASSWORD: PluginName = "mysql_native_password" + val SHA256_PASSWORD: PluginName = "sha256_password" + val CACHING_SHA2_PASSWORD: PluginName = "caching_sha2_password" + + // Legacy Authentication Plugin (deprecated) + val MYSQL_OLD_PASSWORD: PluginName = "mysql_old_password" + + // Authentication plugins for external authentication + val AUTHENTICATION_WINDOWS: PluginName = "authentication_windows" + val AUTHENTICATION_PAM: PluginName = "authentication_pam" + val AUTHENTICATION_LDAP_SIMPLE: PluginName = "authentication_ldap_simple" + val AUTHENTICATION_LDAP_SASL: PluginName = "authentication_ldap_sasl" + + // Kerberos authentication + val AUTHENTICATION_KERBEROS: PluginName = "authentication_kerberos" + + // FIDO authentication (MySQL 8.0.27+) + val AUTHENTICATION_FIDO: PluginName = "authentication_fido" + + // Multi-factor authentication (MySQL 8.0.27+) + val AUTHENTICATION_WEBAUTHN: PluginName = "authentication_webauthn" + + // No login authentication plugin + val MYSQL_NO_LOGIN: PluginName = "mysql_no_login" + + // Test plugins (for testing purposes) + val TEST_PLUGIN_SERVER: PluginName = "test_plugin_server" + val DAEMON_EXAMPLE: PluginName = "daemon_example" + + // Socket peer-credential authentication (Unix socket) + val AUTH_SOCKET: PluginName = "auth_socket" diff --git a/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala new file mode 100644 index 000000000..df1c98246 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/js/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -0,0 +1,45 @@ +/** + * 
Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import scala.concurrent.duration.* + +import com.comcast.ip4s.* + +import cats.effect.* + +import fs2.io.net.* +import fs2.io.net.tls.* + +/** + * Secure HTTP client that supports both HTTP and HTTPS protocols. + * + * Security Features: + * - Validates URI schemes and rejects unsupported protocols + * - Uses TLS for HTTPS connections with proper certificate validation + * - Defaults to secure ports (443 for HTTPS, 80 for HTTP) + * - Prevents credentials from being sent over cleartext connections + * + * This addresses the security vulnerability where AWS credentials + * could be sent over unencrypted HTTP connections. + */ +final case class SimpleHttpClient[F[_]: Network: Async]( + connectTimeout: Duration, + readTimeout: Duration +) extends BasedHttpClient[F]: + + override def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = + if isSecure then + for + socket <- Network[F].client(address) + tlsContext <- Network[F].tlsContext.systemResource + tlsSocket <- tlsContext + .clientBuilder(socket) + .withParameters(TLSParameters(servername = Some(host))) + .build + yield tlsSocket + else Network[F].client(address) diff --git a/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala new file mode 100644 index 000000000..0ffe3efd7 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/jvm/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -0,0 +1,47 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import javax.net.ssl.SNIHostName + +import scala.concurrent.duration.* + +import com.comcast.ip4s.* + +import cats.effect.* + +import fs2.io.net.* +import fs2.io.net.tls.* + +/** + * Secure HTTP client that supports both HTTP and HTTPS protocols. + * + * Security Features: + * - Validates URI schemes and rejects unsupported protocols + * - Uses TLS for HTTPS connections with proper certificate validation + * - Defaults to secure ports (443 for HTTPS, 80 for HTTP) + * - Prevents credentials from being sent over cleartext connections + * + * This addresses the security vulnerability where AWS credentials + * could be sent over unencrypted HTTP connections. + */ +final case class SimpleHttpClient[F[_]: Network: Async]( + connectTimeout: Duration, + readTimeout: Duration +) extends BasedHttpClient[F]: + + override def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = + if isSecure then + for + socket <- Network[F].client(address) + tlsContext <- Network[F].tlsContext.systemResource + tlsSocket <- tlsContext + .clientBuilder(socket) + .withParameters(TLSParameters(serverNames = Some(List(new SNIHostName(host))))) + .build + yield tlsSocket + else Network[F].client(address) diff --git a/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala new file mode 100644 index 000000000..476c732d3 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/native/src/main/scala/ldbc/amazon/client/SimpleHttpClient.scala @@ -0,0 +1,45 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import scala.concurrent.duration.* + +import com.comcast.ip4s.* + +import cats.effect.* + +import fs2.io.net.* +import fs2.io.net.tls.* + +/** + * Secure HTTP client that supports both HTTP and HTTPS protocols. + * + * Security Features: + * - Validates URI schemes and rejects unsupported protocols + * - Uses TLS for HTTPS connections with proper certificate validation + * - Defaults to secure ports (443 for HTTPS, 80 for HTTP) + * - Prevents credentials from being sent over cleartext connections + * + * This addresses the security vulnerability where AWS credentials + * could be sent over unencrypted HTTP connections. + */ +final case class SimpleHttpClient[F[_]: Network: Async]( + connectTimeout: Duration, + readTimeout: Duration +) extends BasedHttpClient[F]: + + override def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] = + if isSecure then + for + socket <- Network[F].client(address) + tlsContext <- Network[F].tlsContext.systemResource + tlsSocket <- tlsContext + .clientBuilder(socket) + .withParameters(TLSParameters(serverName = Some(host))) + .build + yield tlsSocket + else Network[F].client(address) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala new file mode 100644 index 000000000..ecba20f0c --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsBasicCredentials.scala @@ -0,0 +1,34 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.time.Instant + +import ldbc.amazon.identity.AwsCredentials + +/** + * Basic AWS credentials implementation that contains an access key ID and secret access key. + * + * This implementation is used for standard AWS access credentials that consist of a public + * access key ID and a private secret access key. Basic credentials do not include session + * tokens and are typically used for long-term access. + * + * @param accessKeyId The AWS access key ID used for authentication + * @param secretAccessKey The AWS secret access key used for signing requests + * @param validateCredentials Whether these credentials should be validated before use + * @param providerName Optional name of the credentials provider that created these credentials + * @param accountId Optional AWS account ID associated with these credentials + * @param expirationTime Optional expiration time for these credentials, if applicable + */ +final case class AwsBasicCredentials( + accessKeyId: String, + secretAccessKey: String, + validateCredentials: Boolean, + providerName: Option[String], + accountId: Option[String], + expirationTime: Option[Instant] +) extends AwsCredentials diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala new file mode 100644 index 000000000..679ef3dee --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/AwsSessionCredentials.scala @@ -0,0 +1,37 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.time.Instant + +import ldbc.amazon.identity.AwsCredentials + +/** + * AWS session credentials implementation that includes a session token. + * + * Session credentials are temporary credentials that include an access key ID, + * secret access key, and a session token. These are typically obtained from + * AWS Security Token Service (STS) and have a limited lifetime. They are commonly + * used for temporary access, role-based access, or federated access scenarios. + * + * @param accessKeyId The AWS access key ID used for authentication + * @param secretAccessKey The AWS secret access key used for signing requests + * @param sessionToken The session token that must be included with requests + * @param validateCredentials Whether these credentials should be validated before use + * @param providerName Optional name of the credentials provider that created these credentials + * @param accountId Optional AWS account ID associated with these credentials + * @param expirationTime Optional expiration time for these credentials + */ +final case class AwsSessionCredentials( + accessKeyId: String, + secretAccessKey: String, + sessionToken: String, + validateCredentials: Boolean, + providerName: Option[String], + accountId: Option[String], + expirationTime: Option[Instant] +) extends AwsCredentials diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala new file mode 100644 index 000000000..22eb519c5 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProvider.scala @@ -0,0 +1,463 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is 
licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.net.URI +import java.time.Instant + +import cats.syntax.all.* + +import cats.effect.std.Env +import cats.effect.Concurrent + +import fs2.io.file.{ Files, Path } + +import ldbc.amazon.client.HttpClient +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.* +import ldbc.amazon.useragent.BusinessMetricFeatureId +import ldbc.amazon.util.SimpleJsonParser + +/** + * [[AwsCredentialsProvider]] implementation that loads credentials from AWS Container Credential Provider. + * + * This provider is used in containerized environments such as: + * - Amazon ECS with IAM Task Roles + * - Amazon EKS with IAM Roles for Service Accounts (IRSA) + * - Amazon EKS with Pod Identity + * + * The provider reads configuration from environment variables: + * - AWS_CONTAINER_CREDENTIALS_RELATIVE_URI: Relative URI for ECS metadata endpoint + * - AWS_CONTAINER_CREDENTIALS_FULL_URI: Full URI for credential endpoint + * - AWS_CONTAINER_AUTHORIZATION_TOKEN: Authorization token for requests + * - AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE: Path to file containing authorization token + * + * Authentication flow: + * 1. Check environment variables for endpoint configuration + * 2. Load authorization token from direct env var or file + * 3. Make HTTP GET request to credential endpoint with Authorization header + * 4. Parse JSON response to extract temporary credentials + * + * ECS endpoint: http://169.254.170.2/ + * EKS Pod Identity endpoint: http://169.254.170.23/v1/credentials + * + * @param httpClient HTTP client for making credential requests + * @tparam F The effect type + */ +final class ContainerCredentialsProvider[F[_]: Files: Env: Concurrent]( + httpClient: HttpClient[F] +) extends AwsCredentialsProvider[F]: + + /** + * Resolves AWS credentials from the container credential provider endpoint. 
+ * + * This method implements the core logic for retrieving temporary AWS credentials + * from container-based credential providers such as ECS Task Roles or EKS Pod Identity. + * + * The resolution process: + * 1. Load container credentials configuration from environment variables + * 2. Validate that required environment variables are present + * 3. Make HTTP request to the credential endpoint + * 4. Parse and return the temporary credentials + * + * @return F[AwsCredentials] The resolved AWS credentials with access key, secret key, and session token + * @throws SdkClientException if container credentials configuration is missing or credential retrieval fails + */ + override def resolveCredentials(): F[AwsCredentials] = + for + config <- loadContainerCredentialsConfig() + credentials <- config match { + case None => + Concurrent[F].raiseError( + new SdkClientException( + "Unable to load container credentials. " + + "Environment variables AWS_CONTAINER_CREDENTIALS_RELATIVE_URI or " + + "AWS_CONTAINER_CREDENTIALS_FULL_URI are not set." + ) + ) + case Some(containerConfig) => + fetchCredentialsFromEndpoint(containerConfig) + } + yield credentials + + /** + * Loads container credentials configuration from environment variables. + * + * This method reads configuration from the standard AWS environment variables used + * by container credential providers and constructs the appropriate endpoint configuration. 
+ * + * Environment variables checked (in priority order): + * - AWS_CONTAINER_CREDENTIALS_RELATIVE_URI: Relative URI path (used with ECS metadata endpoint) + * - AWS_CONTAINER_CREDENTIALS_FULL_URI: Complete URI (used with custom endpoints like EKS Pod Identity) + * - AWS_CONTAINER_AUTHORIZATION_TOKEN: Direct authorization token for requests + * - AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE: Path to file containing authorization token + * + * When AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is present, the full endpoint URI is constructed + * as: http://169.254.170.2 + relative_uri (ECS metadata endpoint) + * + * When AWS_CONTAINER_CREDENTIALS_FULL_URI is present, it is used as-is (typically for EKS Pod Identity) + * + * @return F[Option[ContainerCredentialsConfig]] Container configuration if environment variables are properly set, None otherwise + */ + private def loadContainerCredentialsConfig(): F[Option[ContainerCredentialsConfig]] = + for + relativeUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") + fullUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_FULL_URI") + directToken <- Env[F].get("AWS_CONTAINER_AUTHORIZATION_TOKEN") + tokenFile <- Env[F].get("AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE") + token <- loadAuthorizationToken(directToken, tokenFile) + yield (relativeUri, fullUri) match { + case (Some(relative), _) => + Some( + ContainerCredentialsConfig( + endpointUri = s"http://169.254.170.2$relative", + authorizationToken = token + ) + ) + case (_, Some(full)) => + Some( + ContainerCredentialsConfig( + endpointUri = full, + authorizationToken = token + ) + ) + case _ => None + } + + /** + * Loads the authorization token from environment variables or file. + * + * This method resolves the authorization token needed for container credential requests + * by checking both direct token environment variable and token file path. + * + * Token resolution priority: + * 1. AWS_CONTAINER_AUTHORIZATION_TOKEN (direct token value) + * 2. 
AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE (path to file containing token) + * + * The authorization token is typically used in EKS environments where the token + * is provided by the Kubernetes service account system. + * + * @param directToken Optional token value from AWS_CONTAINER_AUTHORIZATION_TOKEN + * @param tokenFilePath Optional file path from AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE + * @return F[Option[String]] The authorization token if available, None otherwise + */ + private def loadAuthorizationToken( + directToken: Option[String], + tokenFilePath: Option[String] + ): F[Option[String]] = + (directToken, tokenFilePath) match { + case (Some(token), _) => + Concurrent[F].pure(Some(token.trim).filter(_.nonEmpty)) + case (_, Some(filePath)) => + loadTokenFromFile(Path(filePath)) + case _ => + Concurrent[F].pure(None) + } + + /** + * Loads authorization token from a file path. + * + * This method safely reads the authorization token from the specified file path, + * handling potential file system errors gracefully. This is commonly used in + * Kubernetes environments where tokens are mounted as files in the pod. + * + * The method performs the following operations: + * 1. Check if the file exists + * 2. Read the file content as UTF-8 text + * 3. Trim whitespace and validate the token is non-empty + * 4. 
Handle any file I/O errors by returning None + * + * Common token file locations: + * - /var/run/secrets/eks.amazonaws.com/serviceaccount/token (EKS IRSA) + * - /var/run/secrets/kubernetes.io/serviceaccount/token (Standard Kubernetes SA token) + * + * @param tokenFilePath Path to the token file + * @return F[Option[String]] The token content if file exists and is readable, None otherwise + */ + private def loadTokenFromFile(tokenFilePath: Path): F[Option[String]] = + Files[F].exists(tokenFilePath).flatMap { exists => + if exists then { + Files[F] + .readUtf8(tokenFilePath) + .compile + .string + .map(_.trim) + .map(token => if token.nonEmpty then Some(token) else None) + .handleErrorWith { _ => + Concurrent[F].pure(None) + } + } else { + Concurrent[F].pure(None) + } + } + + /** + * Fetches AWS credentials from the configured container endpoint. + * + * This method performs the actual HTTP request to retrieve temporary AWS credentials + * from the container credential provider endpoint. The request includes appropriate + * headers and authorization tokens when required. + * + * Request flow: + * 1. Build HTTP headers with authorization token (if present) + * 2. Make HTTP GET request to the credential endpoint + * 3. Validate the HTTP response status + * 4. 
Parse the JSON response to extract credentials + * + * The endpoint returns a JSON response containing: + * - AccessKeyId: Temporary access key + * - SecretAccessKey: Temporary secret key + * - Token: Session token + * - Expiration: RFC3339 formatted expiration timestamp + * - RoleArn: ARN of the assumed role (optional) + * + * @param config Container credentials configuration with endpoint URI and auth token + * @return F[AwsCredentials] The temporary AWS credentials from the endpoint + * @throws SdkClientException if the HTTP request fails or response parsing fails + */ + private def fetchCredentialsFromEndpoint(config: ContainerCredentialsConfig): F[AwsCredentials] = + val headers = buildRequestHeaders(config.authorizationToken) + for + response <- httpClient.get(URI.create(config.endpointUri), headers) + _ <- validateHttpResponse(response) + credentials <- parseCredentialsResponse(response.body) + yield credentials + + /** + * Builds HTTP request headers for container credential requests. + * + * This method constructs the necessary HTTP headers for credential endpoint requests, + * including content type specification and authorization token when available. + * + * Standard headers included: + * - Accept: application/json (indicates we expect JSON response) + * - User-Agent: aws-sdk-scala/ldbc (identifies the client SDK) + * - Authorization: token value (when authorization token is available) + * + * The Authorization header is only included when an authorization token is provided, + * which is typically required for EKS Pod Identity but not for ECS Task Roles. 
+ * + * @param authToken Optional authorization token for the request + * @return Map[String, String] HTTP headers for the credential request + */ + private def buildRequestHeaders(authToken: Option[String]): Map[String, String] = + val baseHeaders = Map( + "Accept" -> "application/json", + "User-Agent" -> "aws-sdk-scala/ldbc" + ) + authToken match { + case Some(token) => baseHeaders + ("Authorization" -> token) + case None => baseHeaders + } + + /** + * Validates the HTTP response from the container credential endpoint. + * + * This method checks that the credential endpoint returned a successful HTTP status code + * (2xx range) and raises an exception for any error responses. + * + * Common error scenarios: + * - 401 Unauthorized: Invalid or missing authorization token + * - 403 Forbidden: Insufficient permissions for the requested role + * - 404 Not Found: Invalid endpoint URI or credential path + * - 500 Internal Server Error: Credential service internal error + * + * The error message includes both the HTTP status code and response body + * to provide detailed information for debugging credential issues. + * + * @param response HTTP response from the credential endpoint + * @return F[Unit] Success if status code is 2xx, failure otherwise + * @throws SdkClientException if the HTTP status code indicates an error + */ + private def validateHttpResponse(response: ldbc.amazon.client.HttpResponse): F[Unit] = + if response.statusCode >= 200 && response.statusCode < 300 then { + Concurrent[F].unit + } else { + Concurrent[F].raiseError( + new SdkClientException( + s"Container credentials request failed with status ${ response.statusCode }: ${ response.body }" + ) + ) + } + + /** + * Parses the JSON credentials response from the container endpoint. + * + * This method deserializes the JSON response containing temporary AWS credentials + * and converts it into an AwsSessionCredentials instance. The response follows + * the standard AWS credential response format. 
+ * + * Expected JSON structure: + * ```json + * { + * "AccessKeyId": "ASIA...", + * "SecretAccessKey": "...", + * "Token": "...", + * "Expiration": "2024-01-15T12:00:00Z", + * "RoleArn": "arn:aws:iam::123456789012:role/MyRole" // optional + * } + * ``` + * + * The method performs the following operations: + * 1. Parse the JSON response using SimpleJsonParser + * 2. Extract required fields (AccessKeyId, SecretAccessKey, Token, Expiration) + * 3. Parse the expiration timestamp in RFC3339 format + * 4. Extract the AWS account ID from the role ARN (if present) + * 5. Create AwsSessionCredentials with provider metadata + * + * @param jsonBody JSON response body from the credential endpoint + * @return F[AwsCredentials] Parsed AWS session credentials + * @throws SdkClientException if JSON parsing fails or required fields are missing + */ + private def parseCredentialsResponse(jsonBody: String): F[AwsCredentials] = + Concurrent[F] + .fromEither( + SimpleJsonParser + .parse(jsonBody) + .flatMap(ContainerCredentialsResponse.fromJson) + .left + .map(msg => new SdkClientException(s"Failed to parse JSON: $msg")) + ) + .map { response => + AwsSessionCredentials( + accessKeyId = response.AccessKeyId, + secretAccessKey = response.SecretAccessKey, + sessionToken = response.Token, + validateCredentials = false, + providerName = Some(BusinessMetricFeatureId.CREDENTIALS_CONTAINER.code), + accountId = extractAccountIdFromRoleArn(response.RoleArn), + expirationTime = Some(Instant.parse(response.Expiration)) + ) + } + .adaptError { + case ex => + new SdkClientException(s"Failed to parse container credentials response: ${ ex.getMessage }") + } + + /** + * Extracts the AWS account ID from an IAM role ARN. + * + * This method parses the role ARN returned by container credential endpoints + * to extract the AWS account ID, which is useful for credential metadata. 
+ * + * ARN format: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME + * Example: arn:aws:iam::123456789012:role/EKS-Pod-Identity-Role + * + * The method splits the ARN by colons and extracts the account ID from the 5th position + * (0-indexed position 4). If the ARN format is invalid or the account ID position + * is not available, None is returned. + * + * This information is particularly useful for: + * - Auditing and logging credential usage + * - Validating that credentials are from the expected AWS account + * - Supporting multi-account AWS environments + * + * @param roleArn Optional IAM role ARN from the credential response + * @return Option[String] The AWS account ID if extractable from the ARN, None otherwise + */ + private def extractAccountIdFromRoleArn(roleArn: Option[String]): Option[String] = + roleArn.flatMap { arn => + // ARN format: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME + val arnParts = arn.split(":") + if arnParts.length >= 5 then { + Some(arnParts(4)) + } else { + None + } + } + +/** + * Configuration for container credentials endpoint. + * + * @param endpointUri The full URI to the credential endpoint + * @param authorizationToken Optional authorization token for requests + */ +private case class ContainerCredentialsConfig( + endpointUri: String, + authorizationToken: Option[String] +) + +/** + * JSON response from container credentials endpoint. + * + * @param AccessKeyId AWS access key ID + * @param SecretAccessKey AWS secret access key + * @param Token Session token for temporary credentials + * @param Expiration RFC3339 formatted expiration timestamp + * @param RoleArn Optional ARN of the assumed role + */ +private case class ContainerCredentialsResponse( + AccessKeyId: String, + SecretAccessKey: String, + Token: String, + Expiration: String, + RoleArn: Option[String] = None +) + +/** + * Companion object for parsing JSON responses from container credential endpoints. 
+ */ +private object ContainerCredentialsResponse: + + /** + * Parses a JSON object into a ContainerCredentialsResponse. + * + * This method deserializes the standard AWS container credentials JSON response format, + * extracting all required fields and optional metadata. + * + * Required JSON fields: + * - AccessKeyId: The temporary AWS access key ID + * - SecretAccessKey: The temporary AWS secret access key + * - Token: The session token for temporary credentials + * - Expiration: RFC3339 formatted timestamp indicating when credentials expire + * + * Optional JSON fields: + * - RoleArn: The ARN of the IAM role that was assumed to generate these credentials + * + * @param json The parsed JSON object from the credential endpoint response + * @return Either[String, ContainerCredentialsResponse] Success with parsed response or error message + */ + def fromJson(json: SimpleJsonParser.JsonObject): Either[String, ContainerCredentialsResponse] = + for + accessKeyId <- json.require("AccessKeyId") + secretAccessKey <- json.require("SecretAccessKey") + token <- json.require("Token") + expiration <- json.require("Expiration") + yield ContainerCredentialsResponse( + AccessKeyId = accessKeyId, + SecretAccessKey = secretAccessKey, + Token = token, + Expiration = expiration, + RoleArn = json.get("RoleArn") + ) + +object ContainerCredentialsProvider: + + /** + * Creates a new Container credentials provider with custom HTTP client. + * + * @param httpClient The HTTP client for credential requests + * @tparam F The effect type + * @return A new ContainerCredentialsProvider instance + */ + def create[F[_]: Files: Env: Concurrent]( + httpClient: HttpClient[F] + ): ContainerCredentialsProvider[F] = + new ContainerCredentialsProvider[F](httpClient) + + /** + * Checks if Container credentials are available by verifying + * the presence of required environment variables. 
+ * + * @tparam F The effect type + * @return true if Container credentials are properly configured + */ + def isAvailable[F[_]: Env: Concurrent](): F[Boolean] = + for + relativeUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") + fullUri <- Env[F].get("AWS_CONTAINER_CREDENTIALS_FULL_URI") + yield relativeUri.exists(_.trim.nonEmpty) || fullUri.exists(_.trim.nonEmpty) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala new file mode 100644 index 000000000..baa345757 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/DefaultCredentialsProviderChain.scala @@ -0,0 +1,185 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import scala.concurrent.duration.* + +import cats.syntax.all.* + +import cats.effect.* +import cats.effect.std.{ Env, SystemProperties, UUIDGen } + +import fs2.io.file.Files +import fs2.io.net.* + +import ldbc.amazon.client.* +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.* + +/** + * Default AWS credentials provider chain that matches AWS SDK v2 behavior. + * + * The provider chain attempts to resolve credentials from the following sources in order: + * 1. Java system properties (aws.accessKeyId and aws.secretAccessKey) + * 2. Environment variables (AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY) + * 3. Web identity token file (AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN) + * 4. AWS credentials profile files (~/.aws/credentials) + * 5. Container credentials provider (ECS/EKS) + * 6. 
Instance profile credentials provider (EC2) + * + * The first provider in the chain that successfully provides credentials will be used, + * and the search will stop there. If no provider can provide credentials, an exception is thrown. + * + * Usage example: + * ```scala + * import cats.effect.IO + * + * val credentialsProvider = DefaultCredentialsProviderChain[IO]() + * val credentials: IO[AwsCredentials] = credentialsProvider.resolveCredentials() + * ``` + * + * @tparam F The effect type + */ +class DefaultCredentialsProviderChain[F[_]: Files: Env: SystemProperties: UUIDGen: Concurrent]( + httpClient: HttpClient[F], + region: String +) extends AwsCredentialsProvider[F]: + + /** + * Lazily initialized list of AWS credential providers in order of precedence. + * + * This method creates the standard AWS SDK credential provider chain that matches + * the behavior of AWS SDK for Java v2. The providers are ordered by precedence, + * with more specific/explicit credential sources taking priority over implicit ones. + * + * Provider chain order: + * 1. SystemPropertyCredentialsProvider - Java system properties (aws.accessKeyId, aws.secretAccessKey) + * 2. EnvironmentVariableCredentialsProvider - Environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) + * 3. WebIdentityTokenFileCredentialsProvider - OIDC token authentication (Kubernetes IRSA) + * 4. ProfileCredentialsProvider - AWS credentials file (~/.aws/credentials) + * 5. ContainerCredentialsProvider - ECS task roles and EKS pod identity + * 6. InstanceProfileCredentialsProvider - EC2 instance profile (IMDS) + * + * Each provider is only consulted if the previous providers fail to provide credentials. + * This ordering ensures that explicit credentials (environment variables, system properties) + * take precedence over implicit credentials (instance profiles, container roles). 
+ *
+   * Performance considerations:
+   * - ProfileCredentialsProvider and InstanceProfileCredentialsProvider require initialization
+   * - The list is held in a lazy val, but since it is an F value the initialization
+   *   effects re-run each time the chain is resolved, not just once
+   * - NOTE(review): provider initialization failures are not cached by this val; the
+   *   attempts repeat on every resolution — confirm this is acceptable for call sites
+   *
+   * @return F[List[AwsCredentialsProvider[F]]] The ordered list of credential providers
+   */
+  private lazy val providers: F[List[AwsCredentialsProvider[F]]] =
+    for
+      profileProvider                    <- ProfileCredentialsProvider.default[F]()
+      instanceProfileCredentialsProvider <- InstanceProfileCredentialsProvider.create[F](httpClient)
+    yield List(
+      new SystemPropertyCredentialsProvider[F](),
+      new EnvironmentVariableCredentialsProvider[F](),
+      WebIdentityTokenFileCredentialsProvider.default[F](httpClient, region),
+      profileProvider,
+      ContainerCredentialsProvider.create[F](httpClient),
+      instanceProfileCredentialsProvider
+    )
+
+  /**
+   * Resolves AWS credentials by trying providers in the default chain order.
+   *
+   * This method implements the AWS SDK v2 default credential provider chain behavior,
+   * attempting to resolve credentials from each provider in sequence until one succeeds
+   * or all providers are exhausted.
+   *
+   * Resolution process:
+   * 1. Initialize the provider list (lazy evaluation)
+   * 2. Try each provider in order, starting with highest precedence
+   * 3. Return credentials from the first successful provider
+   * 4. If all providers fail, throw an exception with aggregated error messages
+   *
+   * The first successful provider "wins" and the chain stops there. This behavior
+   * ensures predictable credential resolution and prevents unexpected credential
+   * source changes during application runtime.
+ * + * Common resolution scenarios: + * - Development: Environment variables or system properties + * - CI/CD: Web Identity tokens or environment variables + * - ECS: Container credentials provider + * - EKS: Web Identity tokens (IRSA) or container credentials + * - EC2: Instance profile credentials + * - Local AWS CLI: Profile credentials from ~/.aws/credentials + * + * @return F[AwsCredentials] The resolved AWS credentials from the first successful provider + * @throws SdkClientException if no provider in the chain can provide valid credentials + */ + override def resolveCredentials(): F[AwsCredentials] = + for + providerList <- providers + credentials <- tryProvidersInOrder(providerList, Nil) + yield credentials + + /** + * Attempts to resolve credentials from providers in order, handling failures gracefully. + * + * This method implements the recursive credential resolution logic for the provider chain. + * It tries each provider sequentially and accumulates error messages for debugging purposes. + * + * Error handling strategy: + * - Each provider failure is caught and logged for debugging + * - Failures are expected and normal behavior (e.g., no ~/.aws/credentials file) + * - Only when all providers fail is an exception raised + * - Error messages from all providers are included in the final exception + * + * The method maintains a list of error messages to provide comprehensive debugging + * information when credential resolution completely fails. This helps developers + * understand why credential resolution failed across all providers. 
+ * + * Example error scenarios: + * - SystemPropertyCredentialsProvider: "aws.accessKeyId system property not set" + * - EnvironmentVariableCredentialsProvider: "AWS_ACCESS_KEY_ID environment variable not set" + * - ProfileCredentialsProvider: "Unable to load credentials from profile 'default'" + * - InstanceProfileCredentialsProvider: "Unable to retrieve credentials from IMDS" + * + * @param providers Remaining providers to try in the chain + * @param exceptions Accumulated error messages from failed providers + * @return F[AwsCredentials] The credentials from the first successful provider + * @throws SdkClientException if all providers in the list fail to provide credentials + */ + private def tryProvidersInOrder( + providers: List[AwsCredentialsProvider[F]], + exceptions: List[String] + ): F[AwsCredentials] = + providers match + case Nil => + Concurrent[F].raiseError( + new SdkClientException( + s"Unable to load AWS credentials from any provider in the chain: ${ exceptions.mkString(", ") }" + ) + ) + + case provider :: remainingProviders => + provider.resolveCredentials().recoverWith { ex => + val errorMsg = s"${ provider.getClass.getSimpleName }: ${ ex.getMessage }" + tryProvidersInOrder(remainingProviders, exceptions :+ errorMsg) + } + +object DefaultCredentialsProviderChain: + + /** + * Creates a new default credentials provider chain. 
+ * + * @tparam F The effect type + * @return A new DefaultCredentialsProviderChain instance + */ + def default[F[_]: Files: Env: SystemProperties: Network: UUIDGen: Async]( + region: String + ): DefaultCredentialsProviderChain[F] = + val httpClient = SimpleHttpClient[F]( + connectTimeout = 1.second, + readTimeout = 2.seconds + ) + new DefaultCredentialsProviderChain[F](httpClient, region) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala new file mode 100644 index 000000000..bf0e282e4 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProvider.scala @@ -0,0 +1,50 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import cats.MonadThrow + +import cats.effect.std.Env + +import ldbc.amazon.auth.credentials.internal.SystemSettingsCredentialsProvider +import ldbc.amazon.useragent.BusinessMetricFeatureId +import ldbc.amazon.util.SdkSystemSetting + +/** + * AWS credentials provider that loads credentials from environment variables. + * + * This provider looks for AWS credentials in the following environment variables: + * - AWS_ACCESS_KEY_ID: The AWS access key ID + * - AWS_SECRET_ACCESS_KEY: The AWS secret access key + * - AWS_SESSION_TOKEN: Optional session token for temporary credentials + * + * This provider only checks environment variables and does not fallback to system properties. + * Customers can use this provider when they want to explicitly source credentials from + * environment variables only. 
+ * + * @tparam F The effect type that supports environment variable access and error handling + */ +final class EnvironmentVariableCredentialsProvider[F[_]: Env: MonadThrow] extends SystemSettingsCredentialsProvider[F]: + + /** + * Loads a setting value from environment variables only. + * + * This implementation specifically looks at environment variables and does not check + * system properties, allowing customers to specify a credentials provider that only + * uses environment variables. + * + * @param setting The system setting to load + * @return An effect containing the optional setting value from environment variables + */ + override def loadSetting(setting: SdkSystemSetting): F[Option[String]] = Env[F].get(setting.toString) + + /** + * Returns the provider identifier for business metrics tracking. + * + * @return The business metric feature ID for environment variable credentials + */ + override def provider: String = BusinessMetricFeatureId.CREDENTIALS_ENV_VARS.code diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala new file mode 100644 index 000000000..aa7275b97 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProvider.scala @@ -0,0 +1,437 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.net.URI +import java.time.Instant + +import scala.concurrent.duration.* + +import cats.syntax.all.* + +import cats.effect.{ Async, Concurrent, Ref } +import cats.effect.std.Env + +import fs2.io.net.* + +import ldbc.amazon.client.* +import ldbc.amazon.client.{ HttpClient, SimpleHttpClient } +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.* +import ldbc.amazon.useragent.BusinessMetricFeatureId +import ldbc.amazon.util.SimpleJsonParser + +/** + * [[AwsCredentialsProvider]] implementation that loads credentials from EC2 Instance Metadata Service (IMDS). + * + * This provider is used on Amazon EC2 instances with attached IAM instance profiles. + * It automatically retrieves temporary credentials from the EC2 metadata service at: + * http://169.254.169.254/latest/meta-data/iam/security-credentials/ + * + * The provider supports both IMDSv1 (legacy) and IMDSv2 (recommended) protocols: + * - IMDSv2: Uses session tokens for enhanced security against SSRF attacks + * - IMDSv1: Fallback mode for backward compatibility + * + * Environment variables: + * - AWS_EC2_METADATA_DISABLED: Set to 'true' to disable IMDS credential provider + * - AWS_EC2_METADATA_SERVICE_ENDPOINT: Override default IMDS endpoint (for testing) + * + * Authentication flow: + * 1. Optionally acquire IMDSv2 session token (PUT request with TTL header) + * 2. List available IAM roles from metadata service + * 3. Retrieve credentials for the first available role + * 4. Cache credentials until 4 minutes before expiration + * 5. 
Refresh credentials on access once they fall within the expiration buffer
+ *
+ * @param httpClient HTTP client for metadata service requests
+ * @param credentialsRef Mutable reference for credential caching
+ * @tparam F The effect type
+ */
+final class InstanceProfileCredentialsProvider[F[_]: Env: Concurrent](
+  httpClient:     HttpClient[F],
+  credentialsRef: Ref[F, Option[CachedCredentials]]
+) extends AwsCredentialsProvider[F]:
+
+  private val DEFAULT_IMDS_ENDPOINT      = "http://169.254.169.254"
+  private val METADATA_TOKEN_TTL_SECONDS = 21600 // 6 hours
+  private val CREDENTIAL_REFRESH_BUFFER  = 4.minutes
+
+  /**
+   * Resolves AWS credentials from EC2 instance metadata service.
+   *
+   * This method performs the following steps:
+   * 1. Checks if the metadata service is disabled via environment variable
+   * 2. Retrieves cached credentials if they are still valid
+   * 3. Refreshes credentials from IMDS if needed
+   *
+   * @return AWS credentials from the instance profile
+   * @throws SdkClientException if metadata service is disabled or unreachable
+   */
+  override def resolveCredentials(): F[AwsCredentials] =
+    for
+      disabled <- checkIfDisabled()
+      _        <- Concurrent[F].raiseWhen(disabled)(
+                    new SdkClientException("EC2 metadata service is disabled via AWS_EC2_METADATA_DISABLED")
+                  )
+      cached      <- credentialsRef.get
+      credentials <- cached match {
+                       case Some(creds) if !isExpiringSoon(creds) =>
+                         Concurrent[F].pure(creds.credentials)
+                       case _ =>
+                         refreshCredentials()
+                     }
+    yield credentials
+
+  /**
+   * Checks if the EC2 metadata service is disabled via environment variable.
+   *
+   * The AWS_EC2_METADATA_DISABLED environment variable can be set to "true"
+   * to disable this credentials provider.
+   *
+   * @return true if the metadata service is disabled, false otherwise
+   */
+  private def checkIfDisabled(): F[Boolean] =
+    Env[F].get("AWS_EC2_METADATA_DISABLED").map {
+      case Some(value) => value.toLowerCase == "true"
+      case None        => false
+    }
+
+  /**
+   * Refreshes credentials from the EC2 instance metadata service.
+   *
+   * This method performs the complete IMDS authentication flow:
+   * 1. Gets the IMDS endpoint (default or custom)
+   * 2. Attempts to acquire an IMDSv2 token (falls back to IMDSv1 if it fails)
+   * 3. Retrieves the available IAM role name
+   * 4. Fetches credentials for the role
+   * 5. Caches the credentials with timestamp
+   *
+   * @return Fresh AWS credentials from IMDS
+   */
+  private def refreshCredentials(): F[AwsCredentials] =
+    for
+      endpoint    <- getImdsEndpoint()
+      token       <- acquireMetadataToken(endpoint).attempt.map(_.toOption)
+      roleName    <- getRoleName(endpoint, token)
+      credentials <- getCredentialsForRole(endpoint, token, roleName)
+      cached = CachedCredentials(credentials, Instant.now())
+      _ <- credentialsRef.set(Some(cached))
+    yield credentials
+
+  /**
+   * Gets the IMDS endpoint URL from environment variable or default.
+   *
+   * The AWS_EC2_METADATA_SERVICE_ENDPOINT environment variable can be used
+   * to override the default IMDS endpoint (useful for testing or proxies).
+   *
+   * @return The IMDS endpoint URL without trailing slash
+   */
+  private def getImdsEndpoint(): F[String] =
+    Env[F].get("AWS_EC2_METADATA_SERVICE_ENDPOINT").flatMap {
+      case Some(endpoint) =>
+        Concurrent[F]
+          .raiseUnless(endpoint.matches("^https?://169\\.254\\.169\\.254.*"))(
+            new SecurityException("Invalid IMDS endpoint")
+          )
+          .map(_ => endpoint.stripSuffix("/"))
+      case None => Concurrent[F].pure(DEFAULT_IMDS_ENDPOINT)
+    }
+
+  /**
+   * Acquires an IMDSv2 session token for enhanced security.
+   *
+   * IMDSv2 requires a session token to access metadata, which helps prevent
+   * Server-Side Request Forgery (SSRF) attacks. 
The token is obtained via + * a PUT request with a TTL header. + * + * @param endpoint The IMDS endpoint URL + * @return A session token valid for the configured TTL period + * @throws SdkClientException if token acquisition fails + */ + private def acquireMetadataToken(endpoint: String): F[String] = + val tokenUrl = s"$endpoint/latest/api/token" + val headers = Map( + "X-aws-ec2-metadata-token-ttl-seconds" -> METADATA_TOKEN_TTL_SECONDS.toString + ) + + for + response <- httpClient.put(URI.create(tokenUrl), headers, "") + _ <- validateHttpResponse(response, "Failed to acquire metadata token") + yield response.body.trim + + /** + * Retrieves the IAM role name associated with the instance. + * + * This method lists all available IAM roles from the metadata service + * and selects the first one. In practice, EC2 instances typically have + * only one attached IAM role. + * + * @param endpoint The IMDS endpoint URL + * @param token Optional IMDSv2 session token + * @return The name of the IAM role attached to the instance + * @throws SdkClientException if no roles are available or request fails + */ + private def getRoleName(endpoint: String, token: Option[String]): F[String] = + val roleUrl = s"$endpoint/latest/meta-data/iam/security-credentials/" + val headers = buildRequestHeaders(token) + + for + response <- httpClient.get(URI.create(roleUrl), headers) + _ <- validateHttpResponse(response, "Failed to list IAM roles") + roleName <- parseRoleListResponse(response.body) + yield roleName + + /** + * Retrieves AWS credentials for the specified IAM role. + * + * This method fetches the temporary credentials (access key, secret key, + * and session token) associated with the IAM role. The credentials are + * returned in JSON format by the metadata service. 
+ * + * @param endpoint The IMDS endpoint URL + * @param token Optional IMDSv2 session token + * @param roleName The name of the IAM role + * @return AWS credentials for the specified role + * @throws SdkClientException if credentials cannot be retrieved or parsed + */ + private def getCredentialsForRole( + endpoint: String, + token: Option[String], + roleName: String + ): F[AwsCredentials] = + val credentialsUrl = s"$endpoint/latest/meta-data/iam/security-credentials/$roleName" + val headers = buildRequestHeaders(token) + + for + response <- httpClient.get(URI.create(credentialsUrl), headers) + _ <- validateHttpResponse(response, s"Failed to retrieve credentials for role $roleName") + credentials <- parseCredentialsResponse(response.body, roleName) + yield credentials + + /** + * Builds HTTP request headers for IMDS requests. + * + * Creates appropriate headers for both IMDSv1 and IMDSv2 requests. + * When an IMDSv2 token is available, it's included in the headers. + * + * @param token Optional IMDSv2 session token + * @return Map of HTTP headers for the request + */ + private def buildRequestHeaders(token: Option[String]): Map[String, String] = + val baseHeaders = Map( + "Accept" -> "application/json", + "User-Agent" -> "aws-sdk-scala/ldbc" + ) + token match { + case Some(t) => baseHeaders + ("X-aws-ec2-metadata-token" -> t) + case None => baseHeaders + } + + /** + * Validates HTTP response status codes and provides appropriate error messages. + * + * This method interprets common HTTP status codes in the context of IMDS operations + * and provides meaningful error messages for debugging authentication issues. 
+ * + * @param response The HTTP response to validate + * @param context A descriptive context for error messages + * @return Unit if the response is successful + * @throws SdkClientException with context-specific error message for failed responses + */ + private def validateHttpResponse(response: HttpResponse, context: String): F[Unit] = + response.statusCode match { + case code if code >= 200 && code < 300 => Concurrent[F].unit + case 401 => + Concurrent[F].raiseError( + new SdkClientException(s"$context: Unauthorized (401) - Invalid or expired metadata token") + ) + case 403 => + Concurrent[F].raiseError( + new SdkClientException(s"$context: Forbidden (403) - No instance profile attached") + ) + case 404 => + Concurrent[F].raiseError( + new SdkClientException(s"$context: Not Found (404) - Instance metadata not available") + ) + case code => + Concurrent[F].raiseError( + new SdkClientException(s"$context: HTTP $code - ${ response.body }") + ) + } + + private def parseRoleListResponse(body: String): F[String] = + val roles = body.trim.split('\n').map(_.trim).filter(_.nonEmpty) + roles.headOption match { + case Some(roleName) => Concurrent[F].pure(roleName) + case None => + Concurrent[F].raiseError( + new SdkClientException("No IAM roles found in instance metadata") + ) + } + + private def parseCredentialsResponse(jsonBody: String, roleName: String): F[AwsCredentials] = + Concurrent[F] + .fromEither( + SimpleJsonParser + .parse(jsonBody) + .flatMap(InstanceMetadataCredentialsResponse.fromJson) + .left + .map(msg => new SdkClientException(s"Failed to parse JSON: $msg")) + ) + .flatMap { response => + if response.Code == "Success" then { + Concurrent[F].pure( + AwsSessionCredentials( + accessKeyId = response.AccessKeyId, + secretAccessKey = response.SecretAccessKey, + sessionToken = response.Token, + validateCredentials = false, + providerName = Some(BusinessMetricFeatureId.CREDENTIALS_IMDS.code), + accountId = extractAccountIdFromArn(response.AccessKeyId), + 
expirationTime = Some(Instant.parse(response.Expiration)) + ) + ) + } else { + Concurrent[F].raiseError( + new SdkClientException(s"Failed to retrieve credentials for role $roleName: ${ response.Code }") + ) + } + } + .adaptError { + case ex => + new SdkClientException(s"Failed to parse instance metadata credentials response: ${ ex.getMessage }") + } + + private def extractAccountIdFromArn(accessKeyId: String): Option[String] = + // For instance profile credentials, we don't have the account ID directly + // The access key ID pattern is AKIA... for long-term or ASIA... for temporary credentials + None + + private def isExpiringSoon(cached: CachedCredentials): Boolean = + cached.credentials.expirationTime match { + case Some(expiration) => + val now = Instant.now() + val bufferTime = expiration.minusSeconds(CREDENTIAL_REFRESH_BUFFER.toSeconds) + now.isAfter(bufferTime) + case None => false + } + +/** + * Cached credentials with retrieval timestamp. + */ +private case class CachedCredentials( + credentials: AwsCredentials, + retrievedAt: Instant +) + +/** + * JSON response from EC2 instance metadata service. 
+ */ +private case class InstanceMetadataCredentialsResponse( + Code: String, + LastUpdated: String, + Type: String, + AccessKeyId: String, + SecretAccessKey: String, + Token: String, + Expiration: String +) + +private object InstanceMetadataCredentialsResponse: + + def fromJson(json: SimpleJsonParser.JsonObject): Either[String, InstanceMetadataCredentialsResponse] = + for + code <- json.require("Code") + accessKeyId <- json.require("AccessKeyId") + secretAccessKey <- json.require("SecretAccessKey") + token <- json.require("Token") + expiration <- json.require("Expiration") + yield InstanceMetadataCredentialsResponse( + Code = code, + LastUpdated = json.getOrEmpty("LastUpdated"), + Type = json.getOrEmpty("Type"), + AccessKeyId = accessKeyId, + SecretAccessKey = secretAccessKey, + Token = token, + Expiration = expiration + ) + +object InstanceProfileCredentialsProvider: + + /** + * Creates a new Instance Profile credentials provider with default settings. + * + * @tparam F The effect type + * @return A new InstanceProfileCredentialsProvider instance + */ + def apply[F[_]: Env: Network: Async](): F[InstanceProfileCredentialsProvider[F]] = + for + httpClient <- createDefaultHttpClient[F]() + credentialsRef <- Ref.of[F, Option[CachedCredentials]](None) + yield new InstanceProfileCredentialsProvider[F](httpClient, credentialsRef) + + /** + * Creates a new Instance Profile credentials provider with custom HTTP client. + * + * @param httpClient The HTTP client for metadata service requests + * @tparam F The effect type + * @return A new InstanceProfileCredentialsProvider instance + */ + def create[F[_]: Env: Concurrent]( + httpClient: HttpClient[F] + ): F[InstanceProfileCredentialsProvider[F]] = + Ref.of[F, Option[CachedCredentials]](None).map { credentialsRef => + new InstanceProfileCredentialsProvider[F](httpClient, credentialsRef) + } + + /** + * Creates a default HTTP client optimized for EC2 metadata service. 
+ * + * @tparam F The effect type + * @return A configured HTTP client + */ + private def createDefaultHttpClient[F[_]: Network: Async](): F[HttpClient[F]] = + Async[F].pure( + SimpleHttpClient[F]( + connectTimeout = 2.seconds, + readTimeout = 5.seconds + ) + ) + + /** + * Checks if Instance Profile credentials are available by attempting + * to connect to the EC2 metadata service. + * + * @tparam F The effect type + * @return true if the metadata service is reachable + */ + def isAvailable[F[_]: Env: Network: Async](): F[Boolean] = + for + disabled <- Env[F].get("AWS_EC2_METADATA_DISABLED").map { + case Some(value) => value.toLowerCase == "true" + case None => false + } + available <- if disabled then { + Async[F].pure(false) + } else { + checkMetadataServiceAvailability[F]() + } + yield available + + private def checkMetadataServiceAvailability[F[_]: Network: Async](): F[Boolean] = + val httpClient = SimpleHttpClient[F]( + connectTimeout = 1.second, + readTimeout = 2.seconds + ) + + httpClient + .get( + URI.create("http://169.254.169.254/latest/meta-data/"), + Map.empty + ) + .map(_.statusCode == 200) + .handleErrorWith(_ => Async[F].pure(false)) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala new file mode 100644 index 000000000..cfae8cbe1 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProvider.scala @@ -0,0 +1,211 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.time.Instant + +import cats.syntax.all.* +import cats.MonadThrow + +import cats.effect.* +import cats.effect.std.* + +import fs2.io.file.{ Files, Path } + +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.* + +import ProfileCredentialsProvider.* + +/** + * AWS credentials provider that loads credentials from the AWS credentials file. + * + * This provider reads credentials from the standard AWS credentials file located at + * `~/.aws/credentials` and supports profile-based credential management. The credentials + * file uses INI format with profile sections containing access keys and other configuration. + * + * The provider implements caching and file change detection to avoid unnecessary file I/O + * and provides thread-safe access to credentials through semaphore-based synchronization. + * + * File format example: + * ``` + * [default] + * aws_access_key_id = AKIAIOSFODNN7EXAMPLE + * aws_secret_access_key = wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY + * + * [production] + * aws_access_key_id = AKIAIOSFODNN7EXAMPLE + * aws_secret_access_key = wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY + * aws_session_token = IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMQ== + * ``` + * + * @param profileName The name of the profile to load credentials from + * @param cacheRef Mutable reference for caching parsed credentials and file metadata + * @param semaphore Semaphore for controlling concurrent access to file operations + * @tparam F The effect type that supports file operations, system properties, and concurrency + */ +final class ProfileCredentialsProvider[F[_]: SystemProperties: Files: Concurrent]( + profileName: String, + cacheRef: Ref[F, Option[(ProfileFile, AwsCredentials)]], + semaphore: Semaphore[F] +)(using ev: MonadThrow[F]) + extends AwsCredentialsProvider[F]: + + /** + * Resolves AWS credentials from the specified 
profile in the credentials file. + * + * This method implements intelligent caching by checking if the credentials file + * has been modified since the last read. If the file is unchanged, cached credentials + * are returned. Otherwise, the file is re-parsed and credentials are updated. + * + * @return AWS credentials from the specified profile + * @throws SdkClientException if the credentials file is not found, the profile + * doesn't exist, or required fields are missing + */ + override def resolveCredentials(): F[AwsCredentials] = + for + currentFile <- loadFile + cached <- cacheRef.get + credentials <- cached match + case Some((cachedFile, creds)) if cachedFile.lastModified == currentFile.lastModified => + ev.pure(creds) + case _ => updateCredentials(currentFile) + yield credentials + + /** + * Loads and parses the AWS credentials file from the user's home directory. + * + * This method locates the credentials file at `~/.aws/credentials`, reads its contents, + * and parses the INI-format configuration into profile structures. File metadata + * including last modified time is tracked for cache invalidation. 
+ *
+ * @return Parsed credentials file with profiles and metadata
+ * @throws SdkClientException if the home directory cannot be determined, the file
+ * doesn't exist, or the file format is invalid
+ */
+ private def loadFile: F[ProfileFile] =
+ for
+ homeOpt <- SystemProperties[F].get("user.home")
+ home <- ev.fromOption(homeOpt, new SdkClientException("Unable to locate AWS credentials file: 'user.home' system property is not set"))
+ credentialsPath = Path(s"$home/.aws/credentials")
+ exists <- Files[F].exists(credentialsPath)
+ _ <- ev.raiseUnless(exists)(new SdkClientException(s"File not found: $credentialsPath"))
+ content <- Files[F].readUtf8(credentialsPath).compile.string
+ lastMod <- Files[F].getLastModifiedTime(credentialsPath)
+ profiles <- ev.fromEither(parseProfiles(content))
+ yield ProfileFile(profiles, Instant.ofEpochMilli(lastMod.toMillis))
+
+ /**
+ * Parses the INI-format credentials file content into profile structures.
+ *
+ * This method processes the credentials file line by line, extracting profile
+ * sections and their properties. It supports both `[profile name]` and `[name]`
+ * section headers and handles various property formats.
+ *
+ * @param content The raw content of the credentials file
+ * @return Either an error if parsing fails, or a map of profile names to profile data
+ */
+ private def parseProfiles(content: String): Either[Throwable, Map[String, Profile]] =
+ val profilePattern = "\\[(?:profile\\s+)?(.+)\\]".r
+ val propertyPattern = "^\\s*([^=]+)\\s*=\\s*(.+)\\s*$".r
+
+ val lines = content.linesIterator.toList
+
+ case class State(
+ currentProfile: Option[String] = None,
+ profiles: Map[String, Map[String, String]] = Map.empty
+ )
+
+ val finalState = lines.foldLeft(State()): (state, line) =>
+ line.trim match
+ case _ if line.trim.isEmpty || line.trim.startsWith("#") =>
+ state
+ case profilePattern(name) =>
+ State(Some(name.trim), state.profiles + (name.trim -> state.profiles.getOrElse(name.trim, Map.empty)))
+ case propertyPattern(key, value) =>
+ state.currentProfile match
+ case Some(profile) =>
+ val updatedProps = state.profiles.getOrElse(profile, Map.empty) + (key.trim -> value.trim)
+ state.copy(profiles = state.profiles + (profile -> updatedProps))
+ case None => state
+ case _ => state
+
+ Right(finalState.profiles.map: (name, props) =>
+ name -> Profile(name, props))
+
+ private def parseStaticCredentials(props: Map[String, String]): Option[AwsCredentials] =
+ for
+ accessKeyId <- props.get("aws_access_key_id")
+ secretAccessKey <- props.get("aws_secret_access_key")
+ yield props.get("aws_session_token") match {
+ case Some(sessionToken) =>
+ AwsSessionCredentials(
+ accessKeyId = accessKeyId,
+ secretAccessKey = secretAccessKey,
+ sessionToken = sessionToken,
+ validateCredentials = false,
+ providerName = None,
+ accountId = None,
+ expirationTime = None
+ )
+ case None =>
+ AwsBasicCredentials(
+ accessKeyId = accessKeyId,
+ secretAccessKey = secretAccessKey,
+ validateCredentials = false,
+ providerName = None,
+ accountId = None,
+ expirationTime = None
+ )
+ }
+
+ private def extractCredentials(profileFile: ProfileFile): F[AwsCredentials] =
+ for
+ profile <- ev.fromOption(
+ 
profileFile.profiles.get(profileName),
+ new SdkClientException(s"Profile '$profileName' was not found in the AWS credentials file")
+ )
+ credentials <- ev.fromOption(
+ parseStaticCredentials(profile.properties),
+ new SdkClientException(s"Profile '$profileName' is missing required properties aws_access_key_id and/or aws_secret_access_key")
+ )
+ yield credentials
+
+ private def updateCredentials(profileFile: ProfileFile): F[AwsCredentials] =
+ semaphore.permit.use { _ =>
+ for
+ cached <- cacheRef.get
+ credentials <- cached match
+ case Some((cachedFile, creds)) if cachedFile.lastModified == profileFile.lastModified =>
+ ev.pure(creds)
+ case _ =>
+ for
+ creds <- extractCredentials(profileFile)
+ _ <- cacheRef.set(Some((profileFile, creds)))
+ yield creds
+ yield credentials
+ }
+
+object ProfileCredentialsProvider:
+
+ final case class Profile(
+ name: String,
+ properties: Map[String, String]
+ )
+
+ final case class ProfileFile(
+ profiles: Map[String, Profile],
+ lastModified: Instant
+ )
+
+ def default[F[_]: SystemProperties: Files: Concurrent](
+ profileName: String = "default"
+ ): F[ProfileCredentialsProvider[F]] =
+ for
+ cacheRef <- Ref.of[F, Option[(ProfileFile, AwsCredentials)]](None)
+ semaphore <- Semaphore[F](1)
+ yield new ProfileCredentialsProvider[F](profileName, cacheRef, semaphore)
diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala
new file mode 100644
index 000000000..24d680b5e
--- /dev/null
+++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProvider.scala
@@ -0,0 +1,52 @@
+/**
+ * Copyright (c) 2023-2025 by Takahiko Tominaga
+ * This software is licensed under the MIT License (MIT).
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import cats.MonadThrow + +import cats.effect.std.SystemProperties + +import ldbc.amazon.auth.credentials.internal.SystemSettingsCredentialsProvider +import ldbc.amazon.useragent.BusinessMetricFeatureId +import ldbc.amazon.util.SdkSystemSetting + +/** + * AWS credentials provider that loads credentials from JVM system properties. + * + * This provider looks for AWS credentials in the following system properties: + * - aws.accessKeyId: The AWS access key ID + * - aws.secretAccessKey: The AWS secret access key + * - aws.sessionToken: Optional session token for temporary credentials + * + * This provider only checks system properties and does not fallback to environment variables. + * Customers can use this provider when they want to explicitly source credentials from + * JVM system properties only. + * + * @tparam F The effect type that supports system property access and error handling + */ +final class SystemPropertyCredentialsProvider[F[_]: SystemProperties: MonadThrow] + extends SystemSettingsCredentialsProvider[F]: + + /** + * Loads a setting value from JVM system properties only. + * + * This implementation specifically looks at system properties and does not check + * environment variables, allowing customers to specify a credentials provider that only + * uses system properties. + * + * @param setting The system setting to load + * @return An effect containing the optional setting value from system properties + */ + override def loadSetting(setting: SdkSystemSetting): F[Option[String]] = + SystemProperties[F].get(setting.systemProperty) + + /** + * Returns the provider identifier for business metrics tracking. 
+ * + * @return The business metric feature ID for JVM system property credentials + */ + override def provider: String = BusinessMetricFeatureId.CREDENTIALS_JVM_SYSTEM_PROPERTIES.code diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala new file mode 100644 index 000000000..9b7d3a017 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProvider.scala @@ -0,0 +1,296 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import cats.syntax.all.* + +import cats.effect.* +import cats.effect.std.{ Env, SystemProperties, UUIDGen } + +import fs2.io.file.{ Files, Path } + +import ldbc.amazon.auth.credentials.internal.WebIdentityCredentialsUtils +import ldbc.amazon.client.HttpClient +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.* +import ldbc.amazon.util.SdkSystemSetting + +/** + * [[AwsCredentialsProvider]] implementation that loads credentials from AWS Web Identity Token File. + * + * This provider is primarily used in environments that support OIDC (OpenID Connect) authentication, + * such as Kubernetes with service accounts, EKS with IAM Roles for Service Accounts (IRSA), + * or other containerized environments that provide JWT tokens for AWS authentication. 
+ * + * The provider reads configuration from: + * - Environment variables: AWS_WEB_IDENTITY_TOKEN_FILE, AWS_ROLE_ARN, AWS_ROLE_SESSION_NAME + * - System properties: aws.webIdentityTokenFile, aws.roleArn, aws.roleSessionName + * - AWS config file: web_identity_token_file, role_arn, role_session_name (profile-specific) + * + * Authentication flow: + * 1. Read JWT token from the specified file path + * 2. Use the token to assume the specified IAM role via STS AssumeRoleWithWebIdentity + * 3. Return temporary AWS credentials (access key, secret key, session token) + * + * The JWT token file is typically mounted by container orchestration systems and rotated automatically. + * + * Example Kubernetes configuration with IRSA: + * {{{ + * apiVersion: v1 + * kind: ServiceAccount + * metadata: + * annotations: + * eks.amazonaws.com/role-arn: arn:aws:iam::123456789012:role/my-role + * }}} + * + * This will automatically set: + * - AWS_ROLE_ARN=arn:aws:iam::123456789012:role/my-role + * - AWS_WEB_IDENTITY_TOKEN_FILE=/var/run/secrets/eks.amazonaws.com/serviceaccount/token + * + * @param webIdentityUtils Web Identity credentials utility for STS operations + * @tparam F The effect type + */ +final class WebIdentityTokenFileCredentialsProvider[F[_]: Env: SystemProperties: Concurrent]( + webIdentityUtils: WebIdentityCredentialsUtils[F] +) extends AwsCredentialsProvider[F]: + + /** + * Resolves AWS credentials using Web Identity Token authentication. + * + * This method implements the OIDC (OpenID Connect) authentication flow used in + * containerized environments such as Kubernetes with IRSA (IAM Roles for Service Accounts) + * or other OIDC-enabled platforms. + * + * The resolution process: + * 1. Load Web Identity configuration from environment variables/system properties + * 2. Validate that required configuration (token file path and role ARN) is present + * 3. Use STS AssumeRoleWithWebIdentity to exchange the JWT token for AWS credentials + * 4. 
Return the temporary AWS credentials with session token + * + * This provider is commonly used in: + * - Amazon EKS with IAM Roles for Service Accounts (IRSA) + * - Self-managed Kubernetes clusters with OIDC providers + * - CI/CD environments with OIDC authentication + * - Serverless platforms with Web Identity token support + * + * @return F[AwsCredentials] The resolved AWS credentials with temporary access key, secret key, and session token + * @throws SdkClientException if Web Identity configuration is missing or credential exchange fails + */ + override def resolveCredentials(): F[AwsCredentials] = + for + config <- loadWebIdentityConfig() + credentials <- config match { + case None => + Concurrent[F].raiseError( + new SdkClientException( + "Unable to load Web Identity Token credentials. " + + "Required environment variables (AWS_WEB_IDENTITY_TOKEN_FILE, AWS_ROLE_ARN) or " + + "system properties (aws.webIdentityTokenFile, aws.roleArn) are not set." + ) + ) + case Some(webIdentityConfig) => + webIdentityUtils.assumeRoleWithWebIdentity(webIdentityConfig) + } + yield credentials + + /** + * Loads Web Identity Token credential configuration from environment variables and system properties. + * + * This method attempts to load the required configuration for Web Identity Token authentication + * by checking environment variables first, then falling back to system properties. + * + * Required configuration: + * - Token file path: Path to the JWT token file (typically mounted by Kubernetes) + * - Role ARN: The ARN of the IAM role to assume using the Web Identity token + * + * Optional configuration: + * - Role session name: Custom session name for the assumed role session + * + * Configuration sources (in priority order): + * 1. Environment variables: AWS_WEB_IDENTITY_TOKEN_FILE, AWS_ROLE_ARN, AWS_ROLE_SESSION_NAME + * 2. 
System properties: aws.webIdentityTokenFile, aws.roleArn, aws.roleSessionName + * + * @return F[Option[WebIdentityTokenCredentialProperties]] Web Identity configuration if both token file and role ARN are available, None otherwise + */ + private def loadWebIdentityConfig(): F[Option[WebIdentityTokenCredentialProperties]] = + for + tokenFilePath <- loadTokenFilePath() + roleArn <- loadRoleArn() + roleSessionName <- loadRoleSessionName() + yield (tokenFilePath, roleArn) match { + case (Some(tokenFile), Some(arn)) => + Some( + WebIdentityTokenCredentialProperties( + webIdentityTokenFile = Path(tokenFile), + roleArn = arn, + roleSessionName = roleSessionName + ) + ) + case _ => None + } + + /** + * Loads the Web Identity Token file path from configuration sources. + * + * This method retrieves the file system path to the JWT token file used for + * Web Identity authentication. The token file is typically mounted by container + * orchestration systems and contains a JWT token signed by an OIDC provider. + * + * Configuration sources (in priority order): + * 1. Environment variable: AWS_WEB_IDENTITY_TOKEN_FILE + * 2. System property: aws.webIdentityTokenFile + * + * Common token file paths in Kubernetes environments: + * - /var/run/secrets/eks.amazonaws.com/serviceaccount/token (EKS IRSA) + * - /var/run/secrets/kubernetes.io/serviceaccount/token (Standard Kubernetes SA token) + * + * @return F[Option[String]] The token file path if configured, None otherwise + */ + private def loadTokenFilePath(): F[Option[String]] = + for + envValue <- Env[F].get("AWS_WEB_IDENTITY_TOKEN_FILE") + sysPropValue <- SystemProperties[F].get(SdkSystemSetting.AWS_WEB_IDENTITY_TOKEN_FILE.systemProperty) + yield envValue.orElse(sysPropValue).map(_.trim).filter(_.nonEmpty) + + /** + * Loads the IAM role ARN for Web Identity Token authentication. + * + * This method retrieves the Amazon Resource Name (ARN) of the IAM role that should + * be assumed using the Web Identity token. 
The role must be configured with a trust + * policy that allows the OIDC provider to assume it. + * + * Configuration sources (in priority order): + * 1. Environment variable: AWS_ROLE_ARN + * 2. System property: aws.roleArn + * + * Example role ARN format: + * arn:aws:iam::123456789012:role/EKS-Pod-Identity-Role + * + * The IAM role must have a trust policy similar to: + * ```json + * { + * "Version": "2012-10-17", + * "Statement": [ + * { + * "Effect": "Allow", + * "Principal": { + * "Federated": "arn:aws:iam::ACCOUNT_ID:oidc-provider/OIDC_PROVIDER_URL" + * }, + * "Action": "sts:AssumeRoleWithWebIdentity" + * } + * ] + * } + * ``` + * + * @return F[Option[String]] The IAM role ARN if configured, None otherwise + */ + private def loadRoleArn(): F[Option[String]] = + for + envValue <- Env[F].get("AWS_ROLE_ARN") + sysPropValue <- SystemProperties[F].get(SdkSystemSetting.AWS_ROLE_ARN.systemProperty) + yield envValue.orElse(sysPropValue).map(_.trim).filter(_.nonEmpty) + + /** + * Loads the optional role session name for Web Identity Token authentication. + * + * This method retrieves an optional session name that will be used to identify + * the assumed role session in AWS CloudTrail logs and other AWS services. + * If not provided, a default session name will be generated automatically. + * + * Configuration sources (in priority order): + * 1. Environment variable: AWS_ROLE_SESSION_NAME + * 2. 
System property: aws.roleSessionName + * + * Session name requirements: + * - Must be 2-64 characters long + * - Can contain letters, numbers, and the characters +=,.@- + * - Cannot contain spaces + * + * The session name appears in: + * - AWS CloudTrail logs as the "roleSessionName" field + * - AWS STS GetCallerIdentity response + * - IAM policy evaluation context for condition keys + * + * Example session names: + * - "kubernetes-pod-web-app" + * - "ci-cd-pipeline-123" + * - "user-workload-session" + * + * @return F[Option[String]] The role session name if configured, None otherwise + */ + private def loadRoleSessionName(): F[Option[String]] = + for + envValue <- Env[F].get("AWS_ROLE_SESSION_NAME") + sysPropValue <- SystemProperties[F].get(SdkSystemSetting.AWS_ROLE_SESSION_NAME.systemProperty) + yield envValue.orElse(sysPropValue).map(_.trim).filter(_.nonEmpty) + +/** + * Configuration properties for Web Identity Token credentials. + * + * @param webIdentityTokenFile Path to the JWT token file + * @param roleArn The ARN of the IAM role to assume + * @param roleSessionName Optional session name for the assumed role session + */ +case class WebIdentityTokenCredentialProperties( + webIdentityTokenFile: Path, + roleArn: String, + roleSessionName: Option[String] +) + +object WebIdentityTokenFileCredentialsProvider: + + /** + * Creates a new Web Identity Token File credentials provider with custom HTTP client. 
+ * + * @param httpClient The HTTP client for STS requests + * @param region The AWS region for STS endpoint + * @tparam F The effect type + * @return A new WebIdentityTokenFileCredentialsProvider instance + */ + def default[F[_]: Env: SystemProperties: Files: UUIDGen: Concurrent]( + httpClient: HttpClient[F], + region: String = "us-east-1" + ): WebIdentityTokenFileCredentialsProvider[F] = + val webIdentityUtils = WebIdentityCredentialsUtils.default[F](region, httpClient) + new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils) + + /** + * Creates a new Web Identity Token File credentials provider with custom WebIdentityCredentialsUtils. + * + * @param webIdentityUtils Custom Web Identity credentials utility + * @tparam F The effect type + * @return A new WebIdentityTokenFileCredentialsProvider instance + */ + def create[F[_]: Env: SystemProperties: Concurrent]( + webIdentityUtils: WebIdentityCredentialsUtils[F] + ): WebIdentityTokenFileCredentialsProvider[F] = + new WebIdentityTokenFileCredentialsProvider[F](webIdentityUtils) + + /** + * Checks if Web Identity Token authentication is available by verifying + * the presence of required environment variables or system properties. 
+ * + * @tparam F The effect type + * @return true if Web Identity Token authentication is properly configured + */ + def isAvailable[F[_]: Env: SystemProperties: Concurrent](): F[Boolean] = + for + tokenFile <- Env[F] + .get("AWS_WEB_IDENTITY_TOKEN_FILE") + .flatMap(envValue => + SystemProperties[F] + .get(SdkSystemSetting.AWS_WEB_IDENTITY_TOKEN_FILE.systemProperty) + .map(envValue.orElse(_)) + ) + roleArn <- Env[F] + .get("AWS_ROLE_ARN") + .flatMap(envValue => + SystemProperties[F] + .get(SdkSystemSetting.AWS_ROLE_ARN.systemProperty) + .map(envValue.orElse(_)) + ) + yield tokenFile.exists(_.trim.nonEmpty) && roleArn.exists(_.trim.nonEmpty) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala new file mode 100644 index 000000000..b50795b6b --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/SystemSettingsCredentialsProvider.scala @@ -0,0 +1,174 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials.internal + +import cats.syntax.all.* +import cats.MonadThrow + +import ldbc.amazon.auth.credentials.* +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.* +import ldbc.amazon.util.SdkSystemSetting + +/** + * Base trait for AWS credential providers that load credentials from system settings. + * + * This trait provides a common foundation for credential providers that read AWS credentials + * from system-level configuration sources such as environment variables and Java system properties. 
+ * It implements the standard AWS credential loading pattern with support for both basic + * credentials (access key + secret key) and session credentials (+ session token). + * + * The trait uses the AWS SDK system setting definitions to ensure consistency with + * AWS SDK naming conventions and behavior. It supports the following credential types: + * + * Basic credentials: + * - Access Key ID (AWS_ACCESS_KEY_ID / aws.accessKeyId) + * - Secret Access Key (AWS_SECRET_ACCESS_KEY / aws.secretAccessKey) + * + * Session credentials (temporary credentials): + * - Access Key ID (AWS_ACCESS_KEY_ID / aws.accessKeyId) + * - Secret Access Key (AWS_SECRET_ACCESS_KEY / aws.secretAccessKey) + * - Session Token (AWS_SESSION_TOKEN / aws.sessionToken) + * + * Optional metadata: + * - Account ID (AWS_ACCOUNT_ID / aws.accountId) + * + * Implementing classes must provide: + * 1. loadSetting method to read from their specific configuration source + * 2. provider property to identify the credential source for metrics + * + * Error handling: + * - Missing access key or secret key results in SdkClientException + * - Missing session token is acceptable (falls back to basic credentials) + * - Account ID is optional for all credential types + * + * @tparam F The effect type that supports error handling via MonadThrow + */ +trait SystemSettingsCredentialsProvider[F[_]](using ev: MonadThrow[F]) extends AwsCredentialsProvider[F]: + + /** + * Resolves AWS credentials from system settings (environment variables or system properties). + * + * This method implements the standard AWS credential loading logic used by both + * environment variable and system property credential providers. It attempts to + * load all required and optional credential components from the configured source. + * + * Credential loading process: + * 1. Load access key ID (required) + * 2. Load secret access key (required) + * 3. Load session token (optional) + * 4. Load account ID (optional) + * 5. 
Validate that required credentials are present + * 6. Determine credential type based on session token presence + * 7. Create appropriate credential instance (Basic or Session) + * + * Credential validation: + * - Access key ID and secret access key are required + * - Values are trimmed and must be non-empty after trimming + * - Missing required credentials result in SdkClientException + * - Session token and account ID are optional + * + * Credential types returned: + * - AwsBasicCredentials: When only access key and secret key are provided + * - AwsSessionCredentials: When session token is also provided + * + * Both credential types include provider metadata for AWS usage metrics and debugging: + * - providerName: Identifies the specific credential source + * - accountId: AWS account ID if available + * - validateCredentials: Set to false to skip validation + * - expirationTime: None for system setting credentials (no expiration) + * + * @return F[AwsCredentials] The resolved AWS credentials (Basic or Session type) + * @throws SdkClientException if required credentials (access key or secret key) are missing + */ + override def resolveCredentials(): F[AwsCredentials] = + for + accessKeyOpt <- loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map(_.map(_.trim)) + secretKeyOpt <- loadSetting(SdkSystemSetting.AWS_SECRET_ACCESS_KEY).map(_.map(_.trim)) + sessionTokenOpt <- loadSetting(SdkSystemSetting.AWS_SESSION_TOKEN).map(_.map(_.trim)) + accountId <- loadSetting(SdkSystemSetting.AWS_ACCOUNT_ID).map(_.map(_.trim)) + accessKey <- ev.fromOption( + accessKeyOpt, + new SdkClientException( + s"Unable to load credentials from system settings. Access key must be specified either via environment variable (${ SdkSystemSetting.AWS_ACCESS_KEY_ID }) or system property (${ SdkSystemSetting.AWS_ACCESS_KEY_ID.systemProperty })." + ) + ) + secretKey <- ev.fromOption( + secretKeyOpt, + new SdkClientException( + s"Unable to load credentials from system settings. 
Secret key must be specified either via environment variable (${ SdkSystemSetting.AWS_SECRET_ACCESS_KEY }) or system property (${ SdkSystemSetting.AWS_SECRET_ACCESS_KEY.systemProperty })." + ) + ) + yield sessionTokenOpt match { + case None => + AwsBasicCredentials( + accessKeyId = accessKey, + secretAccessKey = secretKey, + validateCredentials = false, + providerName = Some(provider), + accountId = accountId, + expirationTime = None + ) + case Some(sessionToken) => + AwsSessionCredentials( + accessKeyId = accessKey, + secretAccessKey = secretKey, + sessionToken = sessionToken, + validateCredentials = false, + providerName = Some(provider), + accountId = accountId, + expirationTime = None + ) + } + + /** + * Loads a system setting value from the concrete configuration source. + * + * This abstract method must be implemented by concrete credential providers + * to specify how to read configuration values from their specific source + * (environment variables, system properties, etc.). + * + * The implementation should: + * 1. Read the value from the appropriate configuration source + * 2. Return Some(value) if the setting exists and has a value + * 3. Return None if the setting does not exist or has no value + * 4. Handle any source-specific errors appropriately + * + * Example implementations: + * - Environment variables: Read from System.getenv() or effect-based environment access + * - System properties: Read from System.getProperty() or effect-based property access + * - Configuration files: Parse and read from file-based configuration + * + * The SdkSystemSetting parameter provides both the environment variable name + * and system property name for the setting, allowing implementations to choose + * the appropriate source or implement fallback logic. 
+ * + * @param setting The AWS SDK system setting definition containing environment variable and system property names + * @return F[Option[String]] The setting value if available, None otherwise + */ + def loadSetting(setting: SdkSystemSetting): F[Option[String]] + + /** + * Identifier string for this credential provider used in AWS usage metrics and logging. + * + * This string is included in AWS service requests to track credential provider usage + * and appears in AWS CloudTrail logs for debugging purposes. It should be a short, + * descriptive identifier that uniquely identifies the credential source. + * + * Common provider identifiers: + * - "Environment": For environment variable credentials + * - "SystemProperty": For Java system property credentials + * - "Profile": For AWS credentials file profiles + * - "Container": For ECS/EKS container credentials + * - "InstanceProfile": For EC2 instance profile credentials + * + * This value is used for AWS business metrics and helps AWS understand + * how different credential providers are being used across AWS SDKs. + * + * @return String identifier for this credential provider + */ + def provider: String diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala new file mode 100644 index 000000000..28621a53d --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtils.scala @@ -0,0 +1,171 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials.internal + +import cats.syntax.all.* +import cats.MonadThrow + +import cats.effect.std.UUIDGen +import cats.effect.Concurrent + +import fs2.io.file.{ Files, Path } + +import ldbc.amazon.auth.credentials.* +import ldbc.amazon.client.{ HttpClient, StsClient } +import ldbc.amazon.exception.{ InvalidTokenException, TokenFileNotFoundException } +import ldbc.amazon.identity.AwsCredentials + +/** + * Trait for handling Web Identity Token credentials and STS AssumeRoleWithWebIdentity operations. + * + * This trait provides functionality to: + * - Read JWT tokens from file system + * - Parse STS AssumeRoleWithWebIdentity responses + * - Generate AWS credentials from temporary session tokens + */ +trait WebIdentityCredentialsUtils[F[_]]: + + /** + * Assumes an IAM role using Web Identity Token and returns AWS credentials. + * + * This method performs the STS AssumeRoleWithWebIdentity operation, exchanging + * a Web Identity Token (JWT) for temporary AWS credentials. + * + * @param config The Web Identity Token configuration + * @return AWS credentials with session token + */ + def assumeRoleWithWebIdentity( + config: WebIdentityTokenCredentialProperties + ): F[AwsCredentials] + +object WebIdentityCredentialsUtils: + + /** + * Private implementation of WebIdentityCredentialsUtils. + * + * This implementation handles the complete Web Identity Token flow: + * 1. Reading JWT token from file system + * 2. Validating the token format + * 3. Calling STS AssumeRoleWithWebIdentity + * 4. 
Converting the STS response to AWS credentials + * + * @param stsClient The STS client to use for AssumeRoleWithWebIdentity operations + * @tparam F The effect type that supports file operations and concurrency + */ + private case class Impl[F[_]: Files: Concurrent]( + stsClient: StsClient[F] + ) extends WebIdentityCredentialsUtils[F]: + + override def assumeRoleWithWebIdentity( + config: WebIdentityTokenCredentialProperties + ): F[AwsCredentials] = + for + token <- readTokenFromFile(config.webIdentityTokenFile) + _ <- validateToken(token) + stsRequest = StsClient.AssumeRoleWithWebIdentityRequest( + roleArn = config.roleArn, + webIdentityToken = token, + roleSessionName = config.roleSessionName + ) + stsResponse <- stsClient.assumeRoleWithWebIdentity(stsRequest) + credentials = convertStsResponseToCredentials(stsResponse) + yield credentials + + /** + * Reads JWT token from the specified file path. + * + * @param tokenFilePath Path to the JWT token file + * @return The JWT token content as a string + */ + private def readTokenFromFile(tokenFilePath: Path): F[String] = + for + exists <- Files[F].exists(tokenFilePath) + _ <- Concurrent[F].raiseUnless(exists)( + new TokenFileNotFoundException(s"Web Identity Token file not found: $tokenFilePath") + ) + token <- Files[F].readUtf8(tokenFilePath).compile.string.map(_.trim) + _ <- Concurrent[F].raiseWhen(token.isEmpty)( + new InvalidTokenException(s"Web Identity Token file is empty: $tokenFilePath") + ) + yield token + + /** + * Validates the JWT token format. + * + * This is a basic validation to ensure the token has a JWT-like structure. + * A full implementation would include signature verification and claims validation. 
+ * + * @param token The JWT token to validate + * @return Unit if token is valid + */ + private def validateToken(token: String): F[Unit] = + MonadThrow[F].fromEither { + // Basic JWT format validation (header.payload.signature) + val parts = token.split("\\.") + if parts.length != 3 then { + Left(new InvalidTokenException(s"Invalid JWT token format. Expected 3 parts, got ${ parts.length }")) + } else if parts.exists(_.isEmpty) then { + Left(new InvalidTokenException("JWT token contains empty parts")) + } else { + Right(()) + } + } + + /** + * Converts STS response to AWS credentials. + * + * @param stsResponse The STS AssumeRoleWithWebIdentity response + * @return AWS session credentials + */ + private def convertStsResponseToCredentials( + stsResponse: StsClient.AssumeRoleWithWebIdentityResponse + ): AwsCredentials = + AwsSessionCredentials( + accessKeyId = stsResponse.accessKeyId, + secretAccessKey = stsResponse.secretAccessKey, + sessionToken = stsResponse.sessionToken, + validateCredentials = false, + providerName = None, + accountId = extractAccountIdFromArn(stsResponse.assumedRoleArn), + expirationTime = Some(stsResponse.expiration) + ) + + /** + * Extracts AWS account ID from ARN. + * + * @param arn The AWS ARN (e.g., arn:aws:sts::123456789012:assumed-role/MyRole/MySession) + * @return Optional account ID + */ + private def extractAccountIdFromArn(arn: String): Option[String] = + // ARN format: arn:aws:sts::ACCOUNT_ID:assumed-role/ROLE_NAME/SESSION_NAME + val arnParts = arn.split(":") + arnParts.lift(4) + + /** + * Creates a default implementation of WebIdentityCredentialsUtils. 
+ * + * @param region The AWS region for STS endpoint + * @param httpClient HTTP client for making requests + * @tparam F The effect type + * @return A WebIdentityCredentialsUtils instance + */ + def default[F[_]: Files: UUIDGen: Concurrent]( + region: String, + httpClient: HttpClient[F] + ): WebIdentityCredentialsUtils[F] = + val stsClient = StsClient.build[F](s"https://sts.$region.amazonaws.com/", httpClient) + Impl[F](stsClient) + + /** + * Creates a WebIdentityCredentialsUtils with custom StsClient. + * + * @param stsClient Custom STS client implementation + * @tparam F The effect type + * @return A WebIdentityCredentialsUtils instance + */ + def create[F[_]: Files: Concurrent](stsClient: StsClient[F]): WebIdentityCredentialsUtils[F] = + Impl[F](stsClient) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala new file mode 100644 index 000000000..9fea79bb5 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/AuthTokenGenerator.scala @@ -0,0 +1,36 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.token + +import ldbc.amazon.identity.AwsCredentials + +/** + * A trait for generating authentication tokens using AWS credentials. + * + * This trait defines the contract for creating authentication tokens that can be used + * for various AWS services that support IAM-based authentication. The generated tokens + * provide temporary access based on IAM credentials and policies, eliminating the need + * to store long-term credentials in application code. 
+ * + * @tparam F The effect type that wraps the token generation operations + */ +trait AuthTokenGenerator[F[_]]: + + /** + * Generates an authentication token using the provided AWS credentials. + * + * The generated token provides temporary access to AWS services that support + * IAM-based authentication. The token is typically valid for a limited time + * period and provides access based on the IAM permissions associated with + * the provided credentials. + * + * @param credentials The AWS credentials containing access key ID, secret access key, + * and optional session token for temporary credentials + * @return The authentication token wrapped in the effect type F, ready to be used + * for service authentication + */ + def generateToken(credentials: AwsCredentials): F[String] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala new file mode 100644 index 000000000..3d1fb8c8e --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGenerator.scala @@ -0,0 +1,311 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.token + +import java.net.URLEncoder +import java.nio.charset.StandardCharsets +import java.time.{ Instant, ZoneOffset } +import java.time.format.DateTimeFormatter + +import scala.concurrent.duration.FiniteDuration + +import cats.syntax.all.* + +import cats.effect.kernel.{ Clock, Sync } + +import fs2.{ Chunk, Stream } +import fs2.hashing.{ HashAlgorithm, Hashing } + +import ldbc.amazon.auth.credentials.* +import ldbc.amazon.identity.AwsCredentials + +/** + * Implementation of AuthTokenGenerator for Amazon RDS IAM database authentication. + * + * This class generates authentication tokens that can be used to connect to Amazon RDS + * database instances using IAM credentials instead of traditional database passwords. + * The generated tokens are signed using AWS Signature Version 4 and are valid for + * 15 minutes from the time of generation. + * + * The tokens enable secure, temporary access to RDS databases based on IAM policies + * and eliminate the need to store database passwords in application code. 
+ * + * @param hostname The RDS instance hostname or endpoint + * @param port The database port number (typically 3306 for MySQL) + * @param username The database username for which to generate the token + * @param region The AWS region where the RDS instance is located + * @param clock Clock instance for timestamp generation + * @tparam F The effect type that wraps the token generation operations + * + * @example + * {{{ + * import cats.effect.IO + * import ldbc.amazon.auth.token.RdsIamAuthTokenGenerator + * import ldbc.amazon.identity.AwsCredentials + * + * val generator = new RdsIamAuthTokenGenerator[IO]( + * hostname = "my-db-instance.region.rds.amazonaws.com", + * port = 3306, + * username = "db_user", + * region = "us-east-1" + * ) + * + * val credentials = AwsCredentials( + * accessKeyId = "AKIAIOSFODNN7EXAMPLE", + * secretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + * sessionToken = None + * ) + * + * val token: IO[String] = generator.generateToken(credentials) + * }}} + */ +class RdsIamAuthTokenGenerator[F[_]: Hashing: Sync]( + hostname: String, + port: Int, + username: String, + region: String +)(using clock: Clock[F]) + extends AuthTokenGenerator[F]: + + /** AWS Signature Version 4 algorithm identifier */ + private val ALGORITHM = "AWS4-HMAC-SHA256" + + /** AWS service identifier for RDS database connections */ + private val SERVICE = "rds-db" + + /** Token expiration time in seconds (15 minutes) */ + private val EXPIRES_SECONDS = 900 + + /** AWS4 request terminator string */ + private val TERMINATOR = "aws4_request" + + /** + * Generates a signed authentication token for RDS IAM database authentication. + * + * This method creates a presigned URL-style token that can be used as a password + * when connecting to an RDS database instance. The token is signed using AWS + * Signature Version 4 and includes the necessary IAM credentials and metadata. 
+ * + * The generated token follows the format required by MySQL's mysql_clear_password + * authentication plugin and can be used with SSL/TLS connections to RDS instances. + * + * @param credentials AWS credentials containing access key, secret key, and optional session token + * @return The signed authentication token that can be used as a database password + */ + override def generateToken(credentials: AwsCredentials): F[String] = + for + now <- clock.realTime + dateTime = formatDateTime(now) + date = dateTime.substring(0, 8) + credentialScope = s"$date/$region/$SERVICE/$TERMINATOR" + credential = s"${ credentials.accessKeyId }/$credentialScope" + queryParams = credentials match { + case c: AwsBasicCredentials => buildQueryParams(credential, dateTime, None, username) + case c: AwsSessionCredentials => + buildQueryParams(credential, dateTime, Some(c.sessionToken), username) + } + canonicalRequest = buildCanonicalRequest(s"$hostname:$port", queryParams) + canonicalRequestHash <- sha256Hex(canonicalRequest) + stringToSign = buildStringToSign(dateTime, credentialScope, canonicalRequestHash) + signature <- calculateSignature(credentials.secretAccessKey, date, region, stringToSign) + yield s"$hostname:$port/?$queryParams&X-Amz-Signature=$signature" + + /** + * Formats a FiniteDuration to the ISO 8601 basic format string required by AWS Signature Version 4. + * + * Converts a timestamp to the format "yyyyMMddTHHmmssZ" in UTC timezone, + * which is required for AWS authentication requests. + * + * @param duration Elapsed time since the Unix epoch (as produced by Clock#realTime); it is converted to an Instant before formatting + * @return Formatted datetime string in AWS SigV4 format (e.g., "20230101T120000Z") + */ + private def formatDateTime(duration: FiniteDuration): String = + DateTimeFormatter + .ofPattern("yyyyMMdd'T'HHmmss'Z'") + .withZone(ZoneOffset.UTC) + .format(Instant.EPOCH.plusNanos(duration.toNanos)) + + /** + * Builds the query parameters for the RDS authentication request. 
+ * + * Creates a URL-encoded query string containing all the required AWS Signature Version 4 + * parameters for RDS IAM authentication. The parameters are sorted alphabetically + * as required by the AWS signing process. + * + * @param credential The AWS credential string in format "accessKeyId/scope" + * @param dateTime The formatted timestamp for the request + * @param sessionToken The AWS session token (for temporary credentials) + * @param username The database username to authenticate as + * @return URL-encoded query string with all required authentication parameters + */ + private def buildQueryParams( + credential: String, + dateTime: String, + sessionToken: Option[String], + username: String + ): String = + val baseParams = List( + "Action" -> "connect", + "DBUser" -> username, + "X-Amz-Algorithm" -> ALGORITHM, + "X-Amz-Credential" -> credential, + "X-Amz-Date" -> dateTime, + "X-Amz-Expires" -> EXPIRES_SECONDS.toString, + "X-Amz-SignedHeaders" -> "host" + ) + + val params = sessionToken match + case Some(token) => baseParams :+ ("X-Amz-Security-Token" -> token) + case None => baseParams + + params + .sortBy(_._1) + .map { case (k, v) => s"${ urlEncode(k) }=${ urlEncode(v) }" } + .mkString("&") + + /** + * Constructs the canonical request string according to AWS Signature Version 4 specification. + * + * The canonical request is a standardized representation of the HTTP request that will be + * used for signature calculation. It includes the HTTP method, URI path, query string, + * headers, signed headers list, and payload hash. 
+ * + * @param host The database host including port (e.g., "host.rds.amazonaws.com:3306") + * @param queryString The URL-encoded query parameters + * @return The canonical request string formatted according to AWS SigV4 requirements + */ + private def buildCanonicalRequest(host: String, queryString: String): String = + val method = "GET" + val canonicalUri = "/" + val canonicalHeaders = s"host:$host\n" + val signedHeaders = "host" + val payloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of empty string + s"$method\n$canonicalUri\n$queryString\n$canonicalHeaders\n$signedHeaders\n$payloadHash" + + /** + * Converts a byte array to lowercase hexadecimal string representation. + * + * Used for converting hash outputs to the hex format required by AWS signatures. + * Each byte is formatted as a two-character lowercase hex string. + * + * @param bytes The byte array to convert + * @return Lowercase hexadecimal string representation + */ + private def bytesToHex(bytes: Array[Byte]): String = + bytes.map(b => "%02x".format(b & 0xff)).mkString + + /** + * Computes the SHA-256 hash of a string and returns it as a lowercase hex string. + * + * Uses fs2's streaming hash functionality to compute the SHA-256 hash of the input + * string encoded as UTF-8 bytes. The result is converted to lowercase hexadecimal + * format as required by AWS Signature Version 4. + * + * @param data The string to hash + * @return The SHA-256 hash as a lowercase hexadecimal string + */ + private def sha256Hex(data: String): F[String] = + Stream + .chunk(Chunk.array(data.getBytes(StandardCharsets.UTF_8))) + .through(Hashing[F].hash(HashAlgorithm.SHA256)) + .compile + .lastOrError + .map(hash => bytesToHex(hash.bytes.toArray)) + + /** + * Creates the "String to Sign" for AWS Signature Version 4. + * + * The string to sign is a formatted string that combines the algorithm identifier, + * timestamp, credential scope, and canonical request hash. 
This string will be + * signed using the AWS signing key to produce the final signature. + * + * @param dateTime The ISO 8601 formatted timestamp + * @param credentialScope The credential scope (date/region/service/terminator) + * @param canonicalRequestHash The SHA-256 hash of the canonical request + * @return The formatted string to be signed + */ + private def buildStringToSign( + dateTime: String, + credentialScope: String, + canonicalRequestHash: String + ): String = + s"$ALGORITHM\n$dateTime\n$credentialScope\n$canonicalRequestHash" + + /** + * Computes HMAC-SHA256 hash using the provided key and data. + * + * Uses fs2's streaming HMAC functionality to compute the HMAC-SHA256 of the input + * data using the provided key. This is a core cryptographic operation used in + * AWS Signature Version 4 key derivation and final signature calculation. + * + * @param key The cryptographic key as byte array + * @param data The data to authenticate as string (will be UTF-8 encoded) + * @return The HMAC-SHA256 result as byte array + */ + private def hmacSha256(key: Array[Byte], data: String): F[Array[Byte]] = + Hashing[F] + .hmac(HashAlgorithm.SHA256, Chunk.array(key)) + .use { hmac => + for + _ <- hmac.update(Chunk.array(data.getBytes(StandardCharsets.UTF_8))) + hash <- hmac.hash + yield hash.bytes.toArray + } + + /** + * Derives the AWS Signature Version 4 signing key and calculates the final signature. + * + * Implements the AWS SigV4 key derivation algorithm by computing a series of + * HMAC-SHA256 operations to derive the signing key, then signs the string-to-sign + * with that key. The process follows AWS specifications for signature calculation. + * + * Key derivation steps: + * 1. kDate = HMAC-SHA256("AWS4" + SecretKey, Date) + * 2. kRegion = HMAC-SHA256(kDate, Region) + * 3. kService = HMAC-SHA256(kRegion, Service) + * 4. kSigning = HMAC-SHA256(kService, "aws4_request") + * 5. 
signature = HMAC-SHA256(kSigning, StringToSign) + * + * @param secretKey The AWS secret access key + * @param date The date in YYYYMMDD format + * @param region The AWS region + * @param stringToSign The formatted string to be signed + * @return The final signature as lowercase hexadecimal string + */ + private def calculateSignature( + secretKey: String, + date: String, + region: String, + stringToSign: String + ): F[String] = + for + kDate <- hmacSha256(s"AWS4$secretKey".getBytes(StandardCharsets.UTF_8), date) + kRegion <- hmacSha256(kDate, region) + kService <- hmacSha256(kRegion, SERVICE) + kSigning <- hmacSha256(kService, TERMINATOR) + sig <- hmacSha256(kSigning, stringToSign) + yield bytesToHex(sig) + + /** + * URL encodes a string according to AWS requirements. + * + * Performs URL encoding with specific character replacements required by AWS: + * - Spaces are encoded as %20 (not +) + * - Asterisks (*) are encoded as %2A + * - Tildes (~) remain unencoded + * + * This encoding is required for query parameter values in AWS authentication. + * + * @param value The string to URL encode + * @return The URL encoded string with AWS-specific character handling + */ + private def urlEncode(value: String): String = + URLEncoder + .encode(value, "UTF-8") + .replace("+", "%20") + .replace("*", "%2A") + .replace("%7E", "~") diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/BasedHttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/BasedHttpClient.scala new file mode 100644 index 000000000..56c89fea3 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/BasedHttpClient.scala @@ -0,0 +1,187 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import java.net.URI + +import scala.concurrent.duration.* + +import com.comcast.ip4s.* + +import cats.syntax.all.* + +import cats.effect.* +import cats.effect.syntax.all.* + +import fs2.* +import fs2.io.net.* + +import ldbc.amazon.exception.* + +private[client] trait BasedHttpClient[F[_]: Async] extends HttpClient[F]: + + def connectTimeout: Duration + def readTimeout: Duration + + private def isHttps(uri: URI): Boolean = + Option(uri.getScheme).exists(_.toLowerCase == "https") + + private def getDefaultPort(uri: URI): Int = + if uri.getPort > 0 then uri.getPort + else if isHttps(uri) then 443 + else 80 + + private def validateScheme(uri: URI): F[Unit] = + Option(uri.getScheme) match + case None => Async[F].raiseError(new SdkClientException("URI scheme is required")) + case Some(scheme) if scheme.toLowerCase == "http" => + // Log warning for HTTP usage, but allow it for non-sensitive endpoints + Async[F].unit + case Some(scheme) if scheme.toLowerCase == "https" => Async[F].unit + case Some(unsupported) => + Async[F].raiseError( + new SdkClientException(s"Unsupported URI scheme: $unsupported. Only http and https are supported.") + ) + + private def validateSecurityRequirements(uri: URI): F[Unit] = + // AWS endpoints should always use HTTPS + if Option(uri.getHost).exists(_.contains(".amazonaws.com")) && !isHttps(uri) then + Async[F].raiseError( + new SdkClientException(s"AWS endpoints require HTTPS. 
Attempted to use: ${ uri.getScheme }://${ uri.getHost }") + ) + else Async[F].unit + + def createSocket(address: SocketAddress[Host], isSecure: Boolean, host: String): Resource[F, Socket[F]] + + override def get(uri: URI, headers: Map[String, String]): F[HttpResponse] = + for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, isSecure, "GET", path, headers, None) + yield response + + override def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = + for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" + _).getOrElse("") + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, isSecure, "PUT", path, headers, Some(body)) + yield response + + override def post(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] = + for + _ <- validateScheme(uri) + _ <- validateSecurityRequirements(uri) + host = uri.getHost + port = getDefaultPort(uri) + isSecure = isHttps(uri) + path = Option(uri.getPath).filter(_.nonEmpty).getOrElse("/") + + Option(uri.getQuery).map("?" 
+ _).getOrElse("") + address <- resolveAddress(host, port) + response <- makeRequest(address, host, port, isSecure, "POST", path, headers, Some(body)) + yield response + + private def resolveAddress(host: String, port: Int): F[SocketAddress[Host]] = + for + h <- Async[F].fromOption(Host.fromString(host), new SdkClientException("Invalid host")) + p <- Async[F].fromOption(Port.fromInt(port), new SdkClientException("Invalid port")) + yield SocketAddress(h, p) + + private def sendRequest( + socket: Socket[F], + method: String, + host: String, + port: Int, + isSecure: Boolean, + path: String, + headers: Map[String, String], + body: Option[String] + ): F[Unit] = + val defaultPort = if isSecure then 443 else 80 + val hostHeader = if port == defaultPort then host else s"$host:$port" + val contentHeaders = body match { + case Some(b) => Map("Content-Length" -> b.getBytes("UTF-8").length.toString) + case None => Map.empty + } + val allHeaders = headers ++ contentHeaders + ("Host" -> hostHeader) + ("Connection" -> "close") + + val requestLine = s"$method $path HTTP/1.1\r\n" + val headerLines = allHeaders.map((k, v) => s"$k: $v\r\n").mkString + val requestWithHeaders = requestLine + headerLines + "\r\n" + val fullRequest = body.map(requestWithHeaders + _).getOrElse(requestWithHeaders) + + Stream + .emit(fullRequest) + .through(text.utf8.encode) + .through(socket.writes) + .compile + .drain + + private def parseStatusLine(line: String): F[Int] = + // "HTTP/1.1 200 OK" -> 200 + line.split(" ").toList match + case _ :: code :: _ => + code.toIntOption match + case Some(c) => Async[F].pure(c) + case None => Async[F].raiseError(new CredentialsFetchError(s"Invalid status code: $code")) + case _ => Async[F].raiseError(new CredentialsFetchError(s"Invalid status line: $line")) + + private def parseHeaderLine(line: String): Option[(String, String)] = + line.split(": ", 2).toList match + case key :: value :: Nil => Some(key -> value) + case _ => None + + private def parseHttpResponse(raw: 
String): F[HttpResponse] = + val lines = raw.split("\r\n").toList + + lines match + case statusLine :: rest => + parseStatusLine(statusLine).flatMap: statusCode => + val (headerLines, bodyLines) = rest.span(_.nonEmpty) + val headers = headerLines.flatMap(parseHeaderLine).toMap + val body = bodyLines.drop(1).mkString("\r\n") // drop empty line + + Async[F].pure(HttpResponse(statusCode, headers, body)) + case _ => + Async[F].raiseError(CredentialsFetchError("Empty response")) + + private def receiveResponse(socket: Socket[F]): F[HttpResponse] = + socket.reads + .through(text.utf8.decode) + .compile + .string + .flatMap(parseHttpResponse) + + private def makeRequest( + address: SocketAddress[Host], + host: String, + port: Int, + isSecure: Boolean, + method: String, + path: String, + headers: Map[String, String], + body: Option[String] + ): F[HttpResponse] = + createSocket(address, isSecure, host) + .use { socket => + for + _ <- sendRequest(socket, method, host, port, isSecure, path, headers, body) + response <- receiveResponse(socket) + yield response + } + .timeout(connectTimeout + readTimeout) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala new file mode 100644 index 000000000..829db0620 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpClient.scala @@ -0,0 +1,119 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import java.net.URI + +/** + * Abstract HTTP client interface for making HTTP requests in the AWS authentication plugin. + * + * This trait defines a generic HTTP client abstraction that can be implemented for different + * effect types (F[_]) and HTTP libraries. 
It provides the core HTTP operations needed for + * AWS API communication including GET, PUT, and POST requests. + * + * The client is designed to work with functional effect systems like cats-effect, allowing + * for composable and type-safe HTTP operations. Each method returns the HTTP response + * wrapped in the effect type F[_]. + * + * Implementations of this trait should handle: + * - HTTP connection management and pooling + * - Request/response serialization + * - Error handling and retries + * - SSL/TLS configuration for HTTPS + * - Timeout configuration + * + * @tparam F the effect type that wraps the HTTP response (e.g., IO, Future, Task) + * + * @example {{{ + * import cats.effect.IO + * import java.net.URI + * + * // Usage with cats-effect IO + * def makeRequest(client: HttpClient[IO]): IO[HttpResponse] = { + * val uri = URI.create("https://api.amazonaws.com/endpoint") + * val headers = Map("Authorization" -> "Bearer token") + * + * for { + * response <- client.get(uri, headers) + * } yield response + * } + * }}} + * + * @see [[HttpResponse]] for the response type returned by HTTP operations + * @see [[ldbc.amazon.useragent.BusinessMetricFeatureId]] for user agent metrics tracking + */ +trait HttpClient[F[_]]: + + /** + * Performs an HTTP GET request to the specified URI. + * + * This method executes a GET request with the provided headers and returns + * the HTTP response wrapped in the effect type F[_]. GET requests should be + * idempotent and safe, typically used for retrieving data without side effects. 
+ * + * @param uri the target URI for the GET request, must be a valid HTTP/HTTPS URI + * @param headers a map of HTTP headers to include in the request + * @return an effect containing the HTTP response + * + * @example {{{ + * val uri = URI.create("https://sts.amazonaws.com/") + * val headers = Map( + * "Accept" -> "application/json", + * "User-Agent" -> "ldbc-aws-plugin/1.0" + * ) + * client.get(uri, headers) + * }}} + */ + def get(uri: URI, headers: Map[String, String]): F[HttpResponse] + + /** + * Performs an HTTP PUT request to the specified URI with a request body. + * + * This method executes a PUT request with the provided headers and body content. + * PUT requests are typically used for creating or updating resources and should + * be idempotent. + * + * @param uri the target URI for the PUT request, must be a valid HTTP/HTTPS URI + * @param headers a map of HTTP headers to include in the request + * @param body the request body content as a string + * @return an effect containing the HTTP response + * + * @example {{{ + * val uri = URI.create("https://api.amazonaws.com/resource/123") + * val headers = Map( + * "Content-Type" -> "application/json", + * "Authorization" -> "AWS4-HMAC-SHA256 ..." + * ) + * val body = """{"key": "value"}""" + * client.put(uri, headers, body) + * }}} + */ + def put(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] + + /** + * Performs an HTTP POST request to the specified URI with a request body. + * + * This method executes a POST request with the provided headers and body content. + * POST requests are typically used for creating new resources or submitting data + * and may have side effects. 
+ * + * @param uri the target URI for the POST request, must be a valid HTTP/HTTPS URI + * @param headers a map of HTTP headers to include in the request + * @param body the request body content as a string + * @return an effect containing the HTTP response + * + * @example {{{ + * val uri = URI.create("https://sts.amazonaws.com/") + * val headers = Map( + * "Content-Type" -> "application/x-amz-json-1.1", + * "X-Amz-Target" -> "AWSSecurityTokenServiceV20110615.AssumeRole" + * ) + * val body = """{"RoleArn": "arn:aws:iam::123456789012:role/MyRole"}""" + * client.post(uri, headers, body) + * }}} + */ + def post(uri: URI, headers: Map[String, String], body: String): F[HttpResponse] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala new file mode 100644 index 000000000..adc79de96 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/HttpResponse.scala @@ -0,0 +1,59 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +/** + * Represents an HTTP response from AWS API calls or other HTTP operations. + * + * This case class encapsulates the essential components of an HTTP response including + * the status code, response headers, and response body. It is designed to be immutable + * and type-safe, providing a clean abstraction over HTTP responses in the AWS + * authentication plugin. + * + * The response is typically created by HTTP client implementations and consumed by + * AWS service clients for processing API responses, credential retrieval, and + * authentication operations. 
+ * + * @param statusCode the HTTP status code indicating the result of the request (e.g., 200, 404, 500) + * @param headers a map of HTTP response headers with case-insensitive keys + * @param body the response body content as a string, may be empty for certain responses + * + * @example {{{ + * // Creating a successful response + * val response = HttpResponse( + * statusCode = 200, + * headers = Map( + * "Content-Type" -> "application/json", + * "Content-Length" -> "123" + * ), + * body = """{"access_token": "abc123", "expires_in": 3600}""" + * ) + * + * // Checking response status + * if (response.statusCode >= 200 && response.statusCode < 300) { + * // Process successful response + * println(s"Success: ${response.body}") + * } + * + * // Accessing specific headers + * response.headers.get("Content-Type") match { + * case Some("application/json") => // Parse JSON response + * case _ => // Handle other content types + * } + * }}} + * + * @note HTTP header names should be treated as case-insensitive according to RFC 7230. + * Implementations should normalize header names appropriately. + * + * @see [[HttpClient]] for the interface that produces HttpResponse instances + * @see [[ldbc.amazon.useragent.BusinessMetricFeatureId]] for tracking HTTP client metrics + */ +final case class HttpResponse( + statusCode: Int, + headers: Map[String, String], + body: String +) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala new file mode 100644 index 000000000..80af00aee --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/client/StsClient.scala @@ -0,0 +1,262 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import java.net.{ URI, URLEncoder } +import java.time.{ Instant, ZoneOffset } +import java.time.format.DateTimeFormatter + +import cats.syntax.all.* +import cats.MonadThrow + +import cats.effect.std.UUIDGen +import cats.effect.Concurrent + +import ldbc.amazon.exception.StsException +import ldbc.amazon.util.SimpleXmlParser + +/** + * Trait for AWS STS (Security Token Service) client operations. + * + * This client implements the AWS STS AssumeRoleWithWebIdentity operation to exchange + * a Web Identity Token (JWT) for temporary AWS credentials. + */ +trait StsClient[F[_]]: + + /** + * Performs AssumeRoleWithWebIdentity operation. + * + * @param request The STS request parameters + * @return STS response with temporary credentials + */ + def assumeRoleWithWebIdentity( + request: StsClient.AssumeRoleWithWebIdentityRequest + ): F[StsClient.AssumeRoleWithWebIdentityResponse] + +object StsClient: + + /** + * AssumeRoleWithWebIdentity request parameters. + * + * @param roleArn The ARN of the IAM role to assume + * @param webIdentityToken The Web Identity Token (JWT) + * @param roleSessionName Optional session name for the assumed role session + * @param durationSeconds Optional duration of the session (default 3600 seconds) + */ + case class AssumeRoleWithWebIdentityRequest( + roleArn: String, + webIdentityToken: String, + roleSessionName: Option[String] = None, + durationSeconds: Option[Int] = None + ) + + /** + * STS response containing temporary credentials. 
+ * + * @param accessKeyId The temporary access key ID + * @param secretAccessKey The temporary secret access key + * @param sessionToken The temporary session token + * @param expiration When the credentials expire + * @param assumedRoleArn The ARN of the assumed role + */ + case class AssumeRoleWithWebIdentityResponse( + accessKeyId: String, + secretAccessKey: String, + sessionToken: String, + expiration: Instant, + assumedRoleArn: String + ) + + /** + * Private implementation of StsClient. + * + * This implementation handles: + * 1. Request validation and parameter setup + * 2. HTTP request formatting and execution + * 3. Response parsing and error handling + * + * @param stsEndpoint The STS service endpoint URL + * @param httpClient The HTTP client to use for making requests + * @tparam F The effect type that supports UUID generation and concurrency + */ + private case class Impl[F[_]: UUIDGen: Concurrent]( + stsEndpoint: String, + httpClient: HttpClient[F] + ) extends StsClient[F]: + + override def assumeRoleWithWebIdentity( + request: AssumeRoleWithWebIdentityRequest + ): F[AssumeRoleWithWebIdentityResponse] = + for + _ <- validateRoleArn(request.roleArn) + sessionName <- request.roleSessionName.fold( + UUIDGen[F].randomUUID.map(uuid => s"ldbc-session-$uuid") + )(Concurrent[F].pure) + duration = request.durationSeconds.getOrElse(3600) + + // Build STS request + requestBody = buildRequestBody( + request.copy( + roleSessionName = Some(sessionName), + durationSeconds = Some(duration) + ) + ) + + // Make HTTP request + headers = Map( + "Content-Type" -> "application/x-www-form-urlencoded", + "X-Amz-Target" -> "AWSSecurityTokenServiceV20110615.AssumeRoleWithWebIdentity", + "X-Amz-Date" -> getCurrentTimestamp + ) + + response <- httpClient.post(URI.create(stsEndpoint), headers, requestBody) + _ <- validateHttpResponse(response) + stsResponse <- parseAssumeRoleResponse(response.body) + yield stsResponse + + /** + * Validates that the provided IAM role ARN has the 
correct format. + * + * Role ARNs must match the pattern: arn:aws:iam::ACCOUNT_ID:role/ROLE_NAME + * where ACCOUNT_ID is a 12-digit number and ROLE_NAME contains valid characters. + * + * @param roleArn The IAM role ARN to validate + * @return Unit if valid, raises an error if invalid + */ + private def validateRoleArn(roleArn: String): F[Unit] = + val roleArnPattern = """^arn:aws:iam::\d{12}:role/[\w+=,.@-]+$""".r + roleArnPattern.findFirstIn(roleArn) match { + case Some(_) => Concurrent[F].unit + case None => + Concurrent[F].raiseError( + new IllegalArgumentException( + s"An error occurred (ValidationError) when calling the AssumeRole operation: $roleArn is invalid" + ) + ) + } + + /** + * Creates implementation of StsClient. + * + * @param endpoint STS Endpoint + * @param httpClient HTTP client for making requests + * @tparam F The effect type + * @return A StsClient instance + */ + def build[F[_]: UUIDGen: Concurrent](endpoint: String, httpClient: HttpClient[F]): StsClient[F] = + Impl[F](endpoint, httpClient) + + /** + * Builds the STS request body in AWS Query format. + * + * Creates a URL-encoded form body with the required STS parameters for the + * AssumeRoleWithWebIdentity operation. All parameter names and values are + * properly URL-encoded to ensure safe transmission. 
+ * + * @param request The STS request containing the parameters + * @return URL-encoded form data string for the request body + */ + private def buildRequestBody(request: AssumeRoleWithWebIdentityRequest): String = + val params = Map( + "Action" -> "AssumeRoleWithWebIdentity", + "Version" -> "2011-06-15", + "RoleArn" -> request.roleArn, + "WebIdentityToken" -> request.webIdentityToken, + "RoleSessionName" -> request.roleSessionName.getOrElse("ldbc-session"), + "DurationSeconds" -> request.durationSeconds.getOrElse(3600).toString + ) + + params + .map { + case (key, value) => + s"${ URLEncoder.encode(key, "UTF-8") }=${ URLEncoder.encode(value, "UTF-8") }" + } + .mkString("&") + + /** + * Gets current timestamp in AWS format. + * + * Generates a timestamp string in the ISO 8601 format required by AWS: + * yyyyMMddTHHmmssZ (e.g., "20231201T120000Z"). The timestamp is always + * generated in UTC timezone. + * + * @return formatted timestamp string + */ + private def getCurrentTimestamp: String = + DateTimeFormatter + .ofPattern("yyyyMMdd'T'HHmmss'Z'") + .withZone(ZoneOffset.UTC) + .format(Instant.now()) + + /** + * Validates HTTP response status. + * + * Checks if the HTTP response has a successful status code (2xx range). + * If the status code indicates an error, raises an StsException with + * the status code and response body for debugging. + * + * @param response The HTTP response to validate + * @return Unit if status is successful, raises StsException otherwise + */ + private def validateHttpResponse[F[_]: MonadThrow](response: HttpResponse): F[Unit] = + if response.statusCode >= 200 && response.statusCode < 300 then MonadThrow[F].unit + else + MonadThrow[F].raiseError( + new StsException( + s"STS request failed with status ${ response.statusCode }: ${ response.body }" + ) + ) + + /** + * Parses STS XML response to extract credentials. + * + * Expected XML structure: + * ```xml + * + * + * + * ASIA... + * ... + * ... 
+ * 2023-12-01T12:00:00Z + * + * + * arn:aws:sts::123456789012:assumed-role/MyRole/MySession + * AROA....:MySession + * + * + * + * ``` + */ + private def parseAssumeRoleResponse[F[_]: MonadThrow](xmlBody: String): F[AssumeRoleWithWebIdentityResponse] = + MonadThrow[F] + .catchNonFatal { + + val accessKeyId = SimpleXmlParser.requireTag("AccessKeyId", xmlBody, "AccessKeyId not found") + val secretAccessKey = SimpleXmlParser.requireTag("SecretAccessKey", xmlBody, "SecretAccessKey not found") + val sessionToken = SimpleXmlParser.requireTag("SessionToken", xmlBody, "SessionToken not found") + val expirationStr = SimpleXmlParser.requireTag("Expiration", xmlBody, "Expiration not found") + + val assumedRoleArn = SimpleXmlParser + .extractSection("AssumedRoleUser", xmlBody) + .flatMap(section => SimpleXmlParser.extractTagContent("Arn", section)) + .filter(_.nonEmpty) + .getOrElse(throw new StsException("AssumedRoleArn not found")) + + AssumeRoleWithWebIdentityResponse( + accessKeyId = accessKeyId, + secretAccessKey = secretAccessKey, + sessionToken = sessionToken, + expiration = Instant.parse(expirationStr), + assumedRoleArn = assumedRoleArn + ) + } + .adaptError { + case ex: StsException => ex + case ex => + new StsException(s"Failed to parse STS response: ${ ex.getMessage }") + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/CredentialsFetchError.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/CredentialsFetchError.scala new file mode 100644 index 000000000..eb6036181 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/CredentialsFetchError.scala @@ -0,0 +1,61 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.exception + +/** + * Exception thrown when AWS credentials cannot be fetched from any available credential provider. + * + * This exception represents a high-level failure in the AWS credentials resolution process and + * is typically thrown when: + * - All configured credential providers have failed to provide valid credentials + * - The credential provider chain has been exhausted without success + * - Network connectivity issues prevent access to credential sources + * - Authentication tokens have expired and cannot be refreshed + * - Required environment variables or configuration files are missing + * + * Common credential sources that may fail: + * - **Environment variables**: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_SESSION_TOKEN` + * - **AWS credentials file**: `~/.aws/credentials` + * - **IAM instance profiles**: EC2 metadata service + * - **Web Identity Tokens**: EKS service account tokens, OIDC providers + * - **AWS STS**: AssumeRole, AssumeRoleWithWebIdentity operations + * - **Container credentials**: ECS task metadata endpoints + * + * This exception typically indicates that the application cannot authenticate with AWS services + * and will likely be unable to access any AWS resources until the credential issue is resolved. + * + * Troubleshooting steps: + * 1. Verify AWS configuration and credentials + * 2. Check network connectivity to AWS endpoints + * 3. Validate IAM permissions and trust relationships + * 4. Ensure credential files and environment variables are correctly set + * 5. 
Check for expired tokens or certificates + * + * Example scenarios: + * - EKS pod without proper IRSA configuration + * - EC2 instance without IAM instance profile + * - Local development environment with missing AWS configuration + * - Network policies blocking access to AWS credential endpoints + * + * @param message The detailed error message describing why credential fetching failed, + * including information about which credential sources were attempted + * and what specific failures occurred + */ +class CredentialsFetchError(message: String) extends Exception: + + /** + * Returns the error message for this exception. + * + * The message typically includes details about: + * - Which credential providers were attempted + * - The specific failure reason for each provider + * - Suggested remediation steps + * - Environment or configuration context + * + * @return The detailed error message describing the credential fetch failure + */ + override def getMessage: String = message diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala new file mode 100644 index 000000000..9ac7dd46c --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/InvalidTokenException.scala @@ -0,0 +1,39 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.exception + +/** + * Exception thrown when a Web Identity Token is invalid, malformed, or cannot be processed. 
+ * + * This exception is typically thrown during token validation when: + * - The JWT token does not have the correct format (header.payload.signature) + * - The token file is empty or contains only whitespace + * - The token contains invalid characters or encoding issues + * - The JWT structure is corrupted or missing required components + * - Base64 decoding of JWT segments fails + * - JSON parsing of JWT header or payload fails + * + * Valid JWT token format (3 base64-encoded segments separated by dots): + * ``` + * eyJhbGciOiJSUzI1NiIsImtpZCI6IjEyMyJ9.eyJpc3MiOiJodHRwczovL29pZGMuZWtzLnVzLWVhc3QtMS5hbWF6b25hd3MuY29tL2lkLzEyMyIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0Om15LWFwcCJ9.signature + * ``` + * + * Common scenarios that trigger this exception: + * - Token file contains non-JWT content (e.g., HTML error page, plain text) + * - Network issues resulted in partial token download + * - Token rotation occurred mid-process leaving stale content + * - File system corruption affecting the token file + * + * This exception extends [[WebIdentityTokenException]], the shared base class for Web Identity + * Token failures, so it can be handled together with other token-related errors. + * + * @param message The detailed error message describing the specific validation failure, + * including information about which part of the token validation failed + */ +class InvalidTokenException( + message: String +) extends WebIdentityTokenException(message) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala new file mode 100644 index 000000000..ceafdf7f5 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/SdkClientException.scala @@ -0,0 +1,27 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.exception + +/** + * Base exception for AWS SDK client-side errors. + * + * This exception is thrown when errors occur on the client side, such as: + * - Invalid configuration + * - Network connectivity issues + * - Authentication failures + * - Missing required resources + * + * @param message A descriptive error message explaining the cause of the exception + */ +class SdkClientException(message: String) extends RuntimeException: + + /** + * Returns the error message for this exception. + * + * @return The descriptive error message + */ + override def getMessage: String = message diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala new file mode 100644 index 000000000..fbff2f01a --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/StsException.scala @@ -0,0 +1,37 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.exception + +/** + * Exception thrown when AWS STS (Security Token Service) operations fail. 
+ * + * This exception is typically thrown when: + * - STS AssumeRoleWithWebIdentity request fails (HTTP 400/403/500 errors) + * - Invalid or expired Web Identity Token + * - IAM role doesn't exist or lacks required trust relationships + * - Network connectivity issues to STS endpoints + * - STS response parsing failures + * + * Common STS error scenarios: + * - HTTP 403: Token expired or invalid trust policy + * - HTTP 400: Malformed request or invalid parameters + * - HTTP 500: STS service temporary issues + * + * Example STS endpoint: + * ``` + * https://sts.us-east-1.amazonaws.com/ + * ``` + * + * This exception extends [[SdkClientException]] and carries the STS failure details in its message, + * which is particularly useful in AWS authentication scenarios where retries are common. + * + * @param message The detailed error message including STS response details, HTTP status codes, + * and any relevant context from the failed STS operation + */ +class StsException( + message: String +) extends SdkClientException(message) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala new file mode 100644 index 000000000..837bf20ea --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileAccessException.scala @@ -0,0 +1,58 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.exception + +/** + * Exception thrown when the Web Identity Token file exists but cannot be accessed due to + * permissions, I/O errors, or other file system issues. 
+ * + * This exception is typically thrown in the following scenarios: + * - The file exists but cannot be read due to insufficient file system permissions + * - I/O errors occur while reading the token file (disk errors, network mount issues) + * - File system issues prevent access (read-only file systems, corrupted file systems) + * - File is locked by another process or application + * - Security policy restrictions prevent file access (SELinux, AppArmor, etc.) + * - File descriptor limits are exceeded + * + * Common environments and solutions: + * + * **Container environments (Docker, Kubernetes):** + * - Ensure the container user has read access to mounted token files + * - Verify that volume mounts are configured correctly + * - Check that security contexts allow file access + * + * **File permissions:** + * - Token files typically need 600 permissions (read for owner only) + * - Verify the process owner matches the file owner or has appropriate group access + * - Check parent directory permissions (must be executable to access files within) + * + * **File system issues:** + * - Ensure file systems are mounted correctly (especially network mounts) + * - Verify disk space and inode availability + * - Check for file system corruption or read-only states + * + * Example debugging steps: + * ```bash + * # Check file permissions + * ls -la /var/run/secrets/eks.amazonaws.com/serviceaccount/token + * + * # Test file access as the application user + * sudo -u app-user cat /var/run/secrets/eks.amazonaws.com/serviceaccount/token + * + * # Check file system mount status + * mount | grep /var/run/secrets + * ``` + * + * This exception extends [[WebIdentityTokenException]], the shared base class for Web Identity + * Token failures encountered during token file access attempts. 
+ * + * @param message The detailed error message describing the specific access failure, + * including file path and permission details when available + */ +class TokenFileAccessException( + message: String +) extends WebIdentityTokenException(message) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala new file mode 100644 index 000000000..54b5c3108 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/TokenFileNotFoundException.scala @@ -0,0 +1,54 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.exception + +/** + * Exception thrown when the Web Identity Token file cannot be found at the specified location. 
+ * + * This exception is typically thrown in the following scenarios: + * - The file path specified in `AWS_WEB_IDENTITY_TOKEN_FILE` environment variable does not exist + * - The token file has been moved, deleted, or renamed after the application started + * - The file path is incorrectly configured in the environment + * - Directory structure changes that affect the token file location + * - Container restart scenarios where mounted volumes are not yet available + * + * Common environments where this occurs: + * - **EKS with IRSA (IAM Roles for Service Accounts)**: Token files are mounted by the EKS service + * - **Fargate**: Token files provided via task metadata endpoint + * - **Local development**: When mimicking AWS environments with local token files + * + * Example usage in EKS/IRSA: + * ``` + * AWS_WEB_IDENTITY_TOKEN_FILE=/var/run/secrets/eks.amazonaws.com/serviceaccount/token + * AWS_ROLE_ARN=arn:aws:iam::123456789012:role/my-service-role + * ``` + * + * This exception extends [[WebIdentityTokenException]] and provides enhanced error messages + * that include the attempted file path when available, making debugging easier. + * + * @param message The detailed error message describing the file lookup failure + * @param tokenFilePath The path to the missing token file (optional). When provided, this path + * will be included in the error message returned by [[getMessage]] + */ +class TokenFileNotFoundException( + message: String, + tokenFilePath: Option[String] = None +) extends WebIdentityTokenException(message): + + /** + * Returns the error message for this exception, including the token file path when available. + * + * This method enhances the base error message by appending the token file path when it + * was provided during exception construction, making it easier to identify which specific + * file could not be found. 
+ * + * @return The enhanced error message that includes the file path context when available + */ + override def getMessage: String = + tokenFilePath match + case Some(path) => s"$message (Token file path: $path)" + case None => message diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala new file mode 100644 index 000000000..e12b42ac1 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/exception/WebIdentityTokenException.scala @@ -0,0 +1,16 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.exception + +/** + * Base exception for Web Identity Token operations. + * + * @param message The error message + */ +abstract class WebIdentityTokenException( + message: String +) extends SdkClientException(message) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentials.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentials.scala new file mode 100644 index 000000000..3b55f0187 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentials.scala @@ -0,0 +1,17 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.identity + +/** + * Provides access to the AWS credentials used for accessing services: AWS access key ID and secret access key. These + * credentials are used to securely sign requests to services (e.g., AWS services) that use them for authentication. + * + *
<p>
For more details on AWS access keys, see: + * + * https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html#access-keys-and-secret-access-keys
</p>
+ */ +trait AwsCredentials extends AwsCredentialsIdentity diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala new file mode 100644 index 000000000..ecef0fca3 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsIdentity.scala @@ -0,0 +1,50 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.identity + +import ldbc.amazon.identity.internal.DefaultAwsCredentialsIdentity + +/** + * Provides access to the AWS credentials used for accessing services: AWS access key ID and secret access key. These + * credentials are used to securely sign requests to services (e.g., AWS services) that use them for authentication. + * + *
<p>
For more details on AWS access keys, see: + * + * https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html#access-keys-and-secret-access-keys
</p>
+ */ +trait AwsCredentialsIdentity extends Identity: + + /** + * Retrieve the AWS access key, used to identify the user interacting with services. + */ + def accessKeyId: String + + /** + * Retrieve the AWS secret access key, used to authenticate the user interacting with services. + */ + def secretAccessKey: String + + /** + * Retrieve the AWS account id associated with this credentials identity, if found. + */ + def accountId: Option[String] + +object AwsCredentialsIdentity: + + /** + * Constructs a new credentials object, with the specified AWS access key and AWS secret key. + * + * @param accessKeyId The AWS access key, used to identify the user interacting with services. + * @param secretAccessKey The AWS secret access key, used to authenticate the user interacting with services. + */ + def create(accessKeyId: String, secretAccessKey: String): AwsCredentialsIdentity = DefaultAwsCredentialsIdentity( + accessKeyId = accessKeyId, + secretAccessKey = secretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala new file mode 100644 index 000000000..830eea731 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/AwsCredentialsProvider.scala @@ -0,0 +1,21 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.identity + +trait AwsCredentialsProvider[F[_]]: + + /** + * Returns [[AwsCredentials]] that can be used to authorize an AWS request. Each implementation of AWSCredentialsProvider + * can choose its own strategy for loading credentials. 
For example, an implementation might load credentials from an existing + * key management system, or load new credentials when credentials are rotated. + * + *
<p>
If an error occurs during the loading of credentials or credentials could not be found, a runtime exception will be + * raised.
</p>
+ * + * @return AwsCredentials which the caller can use to authorize an AWS request. + */ + def resolveCredentials(): F[AwsCredentials] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala new file mode 100644 index 000000000..821e23888 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/Identity.scala @@ -0,0 +1,30 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.identity + +import java.time.Instant + +/** + * Interface to represent who is using the SDK, i.e., the identity of the caller, used for authentication. + * + * Examples include [[AwsCredentialsIdentity]] and [[TokenIdentity]]. + */ +trait Identity: + + /** + * The time after which this identity will no longer be valid. If this is empty, + * an expiration time is not known (but the identity may still expire at some + * time in the future). + */ + def expirationTime: Option[Instant] + + /** + * The source that resolved this identity, normally an identity provider. Note that + * this string value would be set by an identity provider implementation and is + * intended to be used for tracking purposes. Avoid building logic on its value. 
+ */ + def providerName: Option[String] diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala new file mode 100644 index 000000000..859065d94 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentity.scala @@ -0,0 +1,90 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.identity.internal + +import java.time.Instant +import java.util.Objects + +import ldbc.amazon.identity.AwsCredentialsIdentity + +/** + * Default implementation of AWS credentials identity. + * + * This implementation provides AWS credentials with complete identity information + * including access keys, account ID, expiration time, and provider name. It serves + * as the concrete implementation used throughout the AWS authentication plugin. + * + * The class implements proper equals/hashCode semantics for credential comparison + * and provides a secure toString representation that does not expose sensitive + * information like secret access keys. 
+ * + * @param accessKeyId The AWS access key ID used for authentication + * @param secretAccessKey The AWS secret access key used for signing requests + * @param accountId Optional AWS account ID associated with these credentials + * @param expirationTime Optional expiration time for temporary credentials + * @param providerName Optional name of the credentials provider that created these credentials + */ +final case class DefaultAwsCredentialsIdentity( + accessKeyId: String, + secretAccessKey: String, + accountId: Option[String], + expirationTime: Option[Instant], + providerName: Option[String] +) extends AwsCredentialsIdentity: + + /** + * Returns a string representation of these credentials without exposing sensitive information. + * + * The secret access key is intentionally omitted from the string representation for security. + * Only the access key ID, provider name, and account ID are included. + * + * @return A secure string representation of the credentials identity + */ + override def toString: String = + val builder = new StringBuilder() + builder.append("AwsCredentialsIdentity(") + builder.append(s"accessKeyId=$accessKeyId") + providerName.foreach(v => builder.append(s", providerName=$v")) + accountId.foreach(v => builder.append(s", accountId=$v")) + builder.append(")") + + builder.result() + + /** + * Compares this credentials identity with another object for equality. + * + * Two credentials identities are considered equal if they have the same + * access key ID, secret access key, and account ID. The expiration time + * and provider name are not considered for equality comparison. 
+ * + * @param obj The object to compare with this credentials identity + * @return true if the objects are equal, false otherwise + */ + override def equals(obj: Any): Boolean = + obj match + case that: DefaultAwsCredentialsIdentity => + (this eq that) || + (Objects.equals(accessKeyId, that.accessKeyId) && + Objects.equals(secretAccessKey, that.secretAccessKey) && + Objects.equals(accountId, that.accountId)) + case _ => false + + /** + * Returns a hash code value for this credentials identity. + * + * The hash code is computed based on the access key ID, secret access key, + * and account ID to ensure consistency with the equals method. The expiration + * time and provider name are not included in the hash code calculation. + * + * @return A hash code value for this credentials identity + */ + override def hashCode(): Int = + var hashCode = 1 + hashCode = 31 * hashCode + Objects.hashCode(accessKeyId) + hashCode = 31 * hashCode + Objects.hashCode(secretAccessKey) + hashCode = 31 * hashCode + Objects.hashCode(accountId) + hashCode diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala new file mode 100644 index 000000000..21c125509 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/plugin/AwsIamAuthenticationPlugin.scala @@ -0,0 +1,143 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.plugin + +import java.nio.charset.StandardCharsets + +import scodec.bits.ByteVector + +import cats.syntax.all.* +import cats.Monad + +import cats.effect.* +import cats.effect.std.{ Env, SystemProperties, UUIDGen } + +import fs2.hashing.Hashing +import fs2.io.file.Files +import fs2.io.net.* + +import ldbc.amazon.auth.credentials.DefaultCredentialsProviderChain +import ldbc.amazon.auth.token.{ AuthTokenGenerator, RdsIamAuthTokenGenerator } +import ldbc.amazon.identity.AwsCredentialsProvider +import ldbc.authentication.plugin.MysqlClearPasswordPlugin + +/** + * AWS IAM authentication plugin for connecting to MySQL databases using IAM credentials. + * + * This plugin enables secure database connections to AWS RDS MySQL instances using AWS IAM + * credentials instead of traditional username/password authentication. It extends the + * `MysqlClearPasswordPlugin` and generates temporary authentication tokens using AWS credentials. + * + * The plugin automatically: + * - Resolves AWS credentials from various sources (environment variables, profiles, IAM roles, etc.) 
+ * - Generates time-limited authentication tokens signed with AWS credentials + * - Handles token refresh and credential rotation transparently + * + * Security benefits: + * - No hardcoded database passwords in application code + * - Leverages existing AWS IAM policies and roles + * - Tokens are automatically time-limited (15 minutes) + * - Supports AWS credential rotation and temporary credentials + * + * Usage requirements: + * - AWS RDS instance must have IAM authentication enabled + * - Database user must be created with IAM authentication + * - Application must have appropriate IAM permissions + * + * @tparam F The effect type that wraps authentication operations + * @param provider The AWS credentials provider for obtaining authentication credentials + * @param generator The token generator for creating RDS IAM authentication tokens + * + * @see [[https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html AWS RDS IAM Database Authentication]] + */ +final class AwsIamAuthenticationPlugin[F[_]: Monad]( + provider: AwsCredentialsProvider[F], + generator: AuthTokenGenerator[F] +) extends MysqlClearPasswordPlugin[F]: + + /** + * Generates an AWS IAM authentication token instead of processing a traditional password. + * + * This method overrides the standard password hashing behavior to generate a temporary + * authentication token using AWS IAM credentials. The process involves: + * + * 1. Resolving AWS credentials from the configured provider + * 2. Generating a signed authentication token using the RDS IAM authentication protocol + * 3. 
Converting the token to bytes for transmission to the MySQL server + * + * The generated token is a pre-signed URL that contains: + * - The RDS endpoint hostname and port + * - The database username + * - AWS signature version 4 (SigV4) authentication parameters + * - A 15-minute expiration time + * + * @param password The original password parameter (ignored in IAM authentication) + * @param scramble The server's challenge bytes (ignored in IAM authentication) + * @return The AWS IAM authentication token as UTF-8 encoded bytes wrapped in the effect type F + * + * @note The password and scramble parameters are inherited from the parent trait but are not + * used in AWS IAM authentication since the token generation process uses AWS credentials + * and cryptographic signing instead of password-based authentication. + */ + override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = + for + credentials <- provider.resolveCredentials() + token <- generator.generateToken(credentials) + yield ByteVector(token.getBytes(StandardCharsets.UTF_8)) + +object AwsIamAuthenticationPlugin: + + /** + * Creates a default AWS IAM authentication plugin with standard credential resolution. + * + * This factory method constructs an `AwsIamAuthenticationPlugin` with the default AWS + * credential provider chain and RDS IAM token generator. The credential provider chain + * attempts to resolve AWS credentials from multiple sources in the following order: + * + * 1. System properties (aws.accessKeyId, aws.secretAccessKey, aws.sessionToken) + * 2. Environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN) + * 3. Web identity token file (for EKS service accounts, Lambda, etc.) + * 4. AWS profile configuration files (~/.aws/credentials, ~/.aws/config) + * 5. Amazon EC2 instance profile credentials + * 6. 
ECS container credentials + * + * The token generator creates pre-signed authentication tokens using AWS Signature Version 4 + * that are valid for 15 minutes and specific to the provided RDS endpoint and database user. + * + * @tparam F The effect type that must support file operations, environment access, system + * properties, network operations, UUID generation, and asynchronous operations + * @param region The AWS region where the RDS instance is located (e.g., "us-east-1") + * @param hostname The RDS instance endpoint hostname (e.g., "mydb.abc123.us-east-1.rds.amazonaws.com") + * @param username The database username configured for IAM authentication + * @param port The database port number (default: 3306 for MySQL) + * @return A configured `AwsIamAuthenticationPlugin` instance ready for database authentication + * + * @example {{{ + * import cats.effect.IO + * import ldbc.amazon.plugin.AwsIamAuthenticationPlugin + * + * val plugin = AwsIamAuthenticationPlugin.default[IO]( + * region = "us-east-1", + * hostname = "mydb.abc123.us-east-1.rds.amazonaws.com", + * username = "myuser", + * port = 3306 + * ) + * }}} + * + * @see [[ldbc.amazon.auth.credentials.DefaultCredentialsProviderChain]] for credential resolution details + * @see [[ldbc.amazon.auth.token.RdsIamAuthTokenGenerator]] for token generation implementation + */ + def default[F[_]: Files: Hashing: Env: SystemProperties: Network: UUIDGen: Async]( + region: String, + hostname: String, + username: String, + port: Int = 3306 + ): AwsIamAuthenticationPlugin[F] = + new AwsIamAuthenticationPlugin[F]( + DefaultCredentialsProviderChain.default[F](region), + new RdsIamAuthTokenGenerator[F](hostname, port, username, region) + ) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala new file mode 100644 index 000000000..9a8d1e6c3 --- 
/dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/useragent/BusinessMetricFeatureId.scala @@ -0,0 +1,194 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.useragent + +/** + * Enumeration of business metric feature identifiers for AWS SDK user agent tracking. + * + * This enum represents a comprehensive set of short-form codes used to identify specific + * features, capabilities, and configurations in the AWS SDK user agent string. These + * metrics help AWS understand which features are being used and how the SDK is configured, + * enabling better service optimization and support. + * + * The feature identifiers cover several categories: + * - SDK features (waiters, paginators, transfer utilities) + * - Retry modes (legacy, standard, adaptive) + * - Compression and optimization features + * - Authentication and credential providers + * - Endpoint and account ID handling modes + * - Checksum validation methods + * - Protocol-specific features + * + * Each feature is assigned a unique single-character or short code that is embedded + * in the user agent string sent with AWS API requests. 
+ * + * @param code the short identifier code used in the user agent string + * + * @example {{{ + * // Getting the code for a specific feature + * val retryCode = BusinessMetricFeatureId.RETRY_MODE_STANDARD.code // Returns "E" + * + * // Using in user agent construction + * val features = List( + * BusinessMetricFeatureId.PAGINATOR, + * BusinessMetricFeatureId.GZIP_REQUEST_COMPRESSION + * ) + * val codes = features.map(_.code).mkString(",") // "C,L" + * }}} + * + * @note Unimplemented metrics: I, K + * @note Unsupported metrics (will never be added): A, H + * + * @see [[ldbc.amazon.client.HttpClient]] for HTTP client implementations that may use these metrics + */ +enum BusinessMetricFeatureId(val code: String): + /** Indicates usage of SDK waiter functionality for polling operations. */ + case WAITER extends BusinessMetricFeatureId("B") + + /** Indicates usage of SDK paginator functionality for handling paginated API responses. */ + case PAGINATOR extends BusinessMetricFeatureId("C") + + /** Indicates usage of legacy retry mode with basic exponential backoff. */ + case RETRY_MODE_LEGACY extends BusinessMetricFeatureId("D") + + /** Indicates usage of standard retry mode with improved backoff strategies. */ + case RETRY_MODE_STANDARD extends BusinessMetricFeatureId("E") + + /** Indicates usage of adaptive retry mode with dynamic rate adjustment. */ + case RETRY_MODE_ADAPTIVE extends BusinessMetricFeatureId("F") + + /** Indicates usage of S3 transfer manager for optimized file uploads/downloads. */ + case S3_TRANSFER extends BusinessMetricFeatureId("G") + + /** Indicates usage of GZIP compression for request bodies. */ + case GZIP_REQUEST_COMPRESSION extends BusinessMetricFeatureId("L") + + /** Indicates usage of RPC v2 protocol with CBOR encoding. */ + case PROTOCOL_RPC_V2_CBOR extends BusinessMetricFeatureId("M") + + /** Indicates usage of custom endpoint override configuration. 
*/ + case ENDPOINT_OVERRIDE extends BusinessMetricFeatureId("N") + + /** Indicates usage of S3 Express One Zone bucket features. */ + case S3_EXPRESS_BUCKET extends BusinessMetricFeatureId("J") + + /** Indicates account ID endpoint mode is set to preferred. */ + case ACCOUNT_ID_MODE_PREFERRED extends BusinessMetricFeatureId("P") + + /** Indicates account ID endpoint mode is disabled. */ + case ACCOUNT_ID_MODE_DISABLED extends BusinessMetricFeatureId("Q") + + /** Indicates account ID endpoint mode is required. */ + case ACCOUNT_ID_MODE_REQUIRED extends BusinessMetricFeatureId("R") + + /** Indicates usage of Signature Version 4A (SigV4A) signing for multi-region requests. */ + case SIGV4A_SIGNING extends BusinessMetricFeatureId("S") + + /** Indicates that account ID has been resolved for the request. */ + case RESOLVED_ACCOUNT_ID extends BusinessMetricFeatureId("T") + + /** Indicates usage of CRC32 checksum for request validation. */ + case FLEXIBLE_CHECKSUMS_REQ_CRC32 extends BusinessMetricFeatureId("U") + + /** Indicates usage of CRC32C checksum for request validation. */ + case FLEXIBLE_CHECKSUMS_REQ_CRC32C extends BusinessMetricFeatureId("V") + + /** Indicates usage of CRC64 checksum for request validation. */ + case FLEXIBLE_CHECKSUMS_REQ_CRC64 extends BusinessMetricFeatureId("W") + + /** Indicates usage of SHA1 checksum for request validation. */ + case FLEXIBLE_CHECKSUMS_REQ_SHA1 extends BusinessMetricFeatureId("X") + + /** Indicates usage of SHA256 checksum for request validation. */ + case FLEXIBLE_CHECKSUMS_REQ_SHA256 extends BusinessMetricFeatureId("Y") + + /** Indicates flexible request checksums are calculated when supported. */ + case FLEXIBLE_CHECKSUMS_REQ_WHEN_SUPPORTED extends BusinessMetricFeatureId("Z") + + /** Indicates flexible request checksums are calculated when required. */ + case FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED extends BusinessMetricFeatureId("a") + + /** Indicates flexible response checksums are validated when supported. 
*/ + case FLEXIBLE_CHECKSUMS_RES_WHEN_SUPPORTED extends BusinessMetricFeatureId("b") + + /** Indicates flexible response checksums are validated when required. */ + case FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED extends BusinessMetricFeatureId("c") + + /** Indicates usage of DynamoDB Object Mapper functionality. */ + case DDB_MAPPER extends BusinessMetricFeatureId("d") + + /** Indicates bearer token credentials loaded from environment variables. */ + case BEARER_SERVICE_ENV_VARS extends BusinessMetricFeatureId("3") + + /** Indicates credentials provided directly in application code. */ + case CREDENTIALS_CODE extends BusinessMetricFeatureId("e") + + /** Indicates credentials loaded from JVM system properties. */ + case CREDENTIALS_JVM_SYSTEM_PROPERTIES extends BusinessMetricFeatureId("f") + + /** Indicates credentials loaded from environment variables. */ + case CREDENTIALS_ENV_VARS extends BusinessMetricFeatureId("g") + + /** Indicates credentials loaded from environment variables with STS web identity token. */ + case CREDENTIALS_ENV_VARS_STS_WEB_ID_TOKEN extends BusinessMetricFeatureId("h") + + /** Indicates credentials obtained via STS AssumeRole operation. */ + case CREDENTIALS_STS_ASSUME_ROLE extends BusinessMetricFeatureId("i") + + /** Indicates credentials obtained via STS AssumeRoleWithSAML operation. */ + case CREDENTIALS_STS_ASSUME_ROLE_SAML extends BusinessMetricFeatureId("j") + + /** Indicates credentials obtained via STS AssumeRoleWithWebIdentity operation. */ + case CREDENTIALS_STS_ASSUME_ROLE_WEB_ID extends BusinessMetricFeatureId("k") + + /** Indicates credentials obtained via STS GetFederationToken operation. */ + case CREDENTIALS_STS_FEDERATION_TOKEN extends BusinessMetricFeatureId("l") + + /** Indicates credentials obtained via STS GetSessionToken operation. */ + case CREDENTIALS_STS_SESSION_TOKEN extends BusinessMetricFeatureId("m") + + /** Indicates credentials loaded from AWS profile configuration. 
*/ + case CREDENTIALS_PROFILE extends BusinessMetricFeatureId("n") + + /** Indicates credentials loaded from AWS profile with source profile configuration. */ + case CREDENTIALS_PROFILE_SOURCE_PROFILE extends BusinessMetricFeatureId("o") + + /** Indicates credentials loaded from AWS profile with named credential provider. */ + case CREDENTIALS_PROFILE_NAMED_PROVIDER extends BusinessMetricFeatureId("p") + + /** Indicates credentials loaded from AWS profile with STS web identity token. */ + case CREDENTIALS_PROFILE_STS_WEB_ID_TOKEN extends BusinessMetricFeatureId("q") + + /** Indicates credentials loaded from AWS profile with SSO configuration. */ + case CREDENTIALS_PROFILE_SSO extends BusinessMetricFeatureId("r") + + /** Indicates credentials obtained via AWS SSO. */ + case CREDENTIALS_SSO extends BusinessMetricFeatureId("s") + + /** Indicates credentials loaded from AWS profile with legacy SSO configuration. */ + case CREDENTIALS_PROFILE_SSO_LEGACY extends BusinessMetricFeatureId("t") + + /** Indicates credentials obtained via legacy AWS SSO. */ + case CREDENTIALS_SSO_LEGACY extends BusinessMetricFeatureId("u") + + /** Indicates credentials loaded from AWS profile with credential process. */ + case CREDENTIALS_PROFILE_PROCESS extends BusinessMetricFeatureId("v") + + /** Indicates credentials obtained via credential process. */ + case CREDENTIALS_PROCESS extends BusinessMetricFeatureId("w") + + /** Indicates credentials obtained via HTTP credential provider. */ + case CREDENTIALS_HTTP extends BusinessMetricFeatureId("z") + + /** Indicates credentials obtained from EC2 Instance Metadata Service (IMDS). */ + case CREDENTIALS_IMDS extends BusinessMetricFeatureId("0") + + /** Indicates credentials obtained from container metadata service. */ + case CREDENTIALS_CONTAINER extends BusinessMetricFeatureId("1") + + /** Indicates an unknown or unrecognized business metric feature. 
*/ + case UNKNOWN extends BusinessMetricFeatureId("Unknown") diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala new file mode 100644 index 000000000..7c1119681 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleJsonParser.scala @@ -0,0 +1,223 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.util + +/** + * Simple JSON parser for AWS credential responses. + * + * This parser is designed specifically for parsing AWS service responses that contain + * credentials and other simple data structures. It avoids external dependencies to + * maintain compatibility with Scala.js and Scala Native environments. + * + * The parser handles flat JSON objects with string, number, boolean, and null values + * but does not support nested objects or arrays. + */ +object SimpleJsonParser: + + /** + * Represents a parsed JSON object with string-valued fields. + * + * This class provides convenient methods for accessing JSON fields with + * proper null handling and error reporting for missing required fields. + * + * @param fields Map of field names to optional string values (None represents JSON null) + */ + case class JsonObject(fields: Map[String, Option[String]]): + + /** + * Gets an optional field value. + * + * @param key The field name to retrieve + * @return Some(value) if the field exists and is not null, None otherwise + */ + def get(key: String): Option[String] = fields.get(key).flatten + + /** + * Gets a field value with empty string as fallback. 
+ * + * @param key The field name to retrieve + * @return The field value or empty string if the field is missing or null + */ + def getOrEmpty(key: String): String = fields.get(key).flatten.getOrElse("") + + /** + * Requires a field to exist with a non-null value. + * + * This method is useful for extracting required fields from JSON responses + * with proper error messaging. + * + * @param key The field name to retrieve + * @return Right(value) if the field exists and is not null, Left(error message) otherwise + */ + def require(key: String): Either[String, String] = + fields.get(key) match + case None => Left(s"Required field '$key' not found") + case Some(None) => Left(s"Field '$key' is null") + case Some(Some(value)) => Right(value) + + /** + * Parses a flat JSON object with string values. + * + * Supports: + * - String values + * - Number values (returned as string) + * - Boolean values (returned as "true"/"false") + * - Null values (returned as empty string) + * + * Does NOT support: + * - Nested objects + * - Arrays + */ + def parse(json: String): Either[String, JsonObject] = + try + val trimmed = json.trim + if !trimmed.startsWith("{") || !trimmed.endsWith("}") then Left("Invalid JSON: must be an object") + else + val content = trimmed.substring(1, trimmed.length - 1).trim + if content.isEmpty then Right(JsonObject(Map.empty)) + else + val fields = parseFields(content) + Right(JsonObject(fields)) + catch case ex: Exception => Left(s"JSON parse error: ${ ex.getMessage }") + + private def parseFields(content: String): Map[String, Option[String]] = + val result = scala.collection.mutable.Map[String, Option[String]]() + var idx = 0 + + while idx < content.length do + // Skip whitespace + idx = skipWhitespace(content, idx) + if idx < content.length && content.charAt(idx) != '}' then + // Parse key + val (key, nextIdx1) = parseString(content, idx) + idx = skipWhitespace(content, nextIdx1) + + // Expect colon + if idx >= content.length || content.charAt(idx) != 
':' then + throw new IllegalArgumentException(s"Expected ':' after key '$key' at position $idx") + + idx = skipWhitespace(content, idx + 1) + + // Parse value + val (value, nextIdx2) = parseValue(content, idx) + result(key) = Option(value) + idx = skipWhitespace(content, nextIdx2) + + // Skip comma if present + if idx < content.length && content.charAt(idx) == ',' then idx += 1 + + result.toMap + + private def skipWhitespace(s: String, from: Int): Int = + var idx = from + while idx < s.length && s.charAt(idx).isWhitespace do idx += 1 + idx + + private def parseString(s: String, from: Int): (String, Int) = + if from >= s.length || s.charAt(from) != '"' then + throw new IllegalArgumentException(s"Expected '\"' at position $from") + + val sb = new StringBuilder + var idx = from + 1 + var escaped = false + + while idx < s.length do + val ch = s.charAt(idx) + if escaped then + ch match + case '"' => sb.append('"') + case '\\' => sb.append('\\') + case '/' => sb.append('/') + case 'b' => sb.append('\b') + case 'f' => sb.append('\f') + case 'n' => sb.append('\n') + case 'r' => sb.append('\r') + case 't' => sb.append('\t') + case 'u' => + if idx + 4 >= s.length then throw new IllegalArgumentException("Invalid unicode escape sequence") + val hex = s.substring(idx + 1, idx + 5) + sb.append(Integer.parseInt(hex, 16).toChar) + idx += 4 + case _ => throw new IllegalArgumentException(s"Invalid escape sequence: \\$ch") + escaped = false + else if ch == '\\' then escaped = true + else if ch == '"' then return (sb.toString, idx + 1) + else sb.append(ch) + idx += 1 + + throw new IllegalArgumentException("Unterminated string") + + private def parseValue(s: String, from: Int): (String, Int) = + if from >= s.length then throw new IllegalArgumentException(s"Unexpected end of input at position $from") + + val ch = s.charAt(from) + + if ch == '"' then + // String value + parseString(s, from) + else if ch == 'n' && s.length >= from + 4 && s.substring(from, from + 4) == "null" then + // 
null value + (null, from + 4) + else if ch == 't' && s.length >= from + 4 && s.substring(from, from + 4) == "true" then ("true", from + 4) + else if ch == 'f' && s.length >= from + 5 && s.substring(from, from + 5) == "false" then ("false", from + 5) + else if ch == '-' || ch.isDigit then + // Number value + var idx = from + while idx < s.length && isNumberChar(s.charAt(idx)) do idx += 1 + (s.substring(from, idx), idx) + else if ch == '{' then + // Nested object - skip it entirely + val endIdx = findMatchingBrace(s, from) + ("{...}", endIdx + 1) + else if ch == '[' then + // Array - skip it entirely + val endIdx = findMatchingBracket(s, from) + ("[...]", endIdx + 1) + else throw new IllegalArgumentException(s"Unexpected character '$ch' at position $from") + + private def isNumberChar(ch: Char): Boolean = + ch.isDigit || ch == '.' || ch == '-' || ch == '+' || ch == 'e' || ch == 'E' + + private def findMatchingBrace(s: String, from: Int): Int = + var depth = 0 + var idx = from + var inString = false + var escaped = false + + while idx < s.length do + val ch = s.charAt(idx) + if escaped then escaped = false + else if ch == '\\' && inString then escaped = true + else if ch == '"' then inString = !inString + else if !inString then + if ch == '{' then depth += 1 + else if ch == '}' then + depth -= 1 + if depth == 0 then return idx + idx += 1 + + throw new IllegalArgumentException("Unmatched brace") + + private def findMatchingBracket(s: String, from: Int): Int = + var depth = 0 + var idx = from + var inString = false + var escaped = false + + while idx < s.length do + val ch = s.charAt(idx) + if escaped then escaped = false + else if ch == '\\' && inString then escaped = true + else if ch == '"' then inString = !inString + else if !inString then + if ch == '[' then depth += 1 + else if ch == ']' then + depth -= 1 + if depth == 0 then return idx + idx += 1 + + throw new IllegalArgumentException("Unmatched bracket") diff --git 
a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala new file mode 100644
index 000000000..71604e36c
--- /dev/null
+++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SimpleXmlParser.scala
@@ -0,0 +1,104 @@
+/**
+ * Copyright (c) 2023-2025 by Takahiko Tominaga
+ * This software is licensed under the MIT License (MIT).
+ * For more information see LICENSE or https://opensource.org/licenses/MIT
+ */
+
+package ldbc.amazon.util
+
+/**
+ * Simple XML parser utility for extracting content from XML documents.
+ *
+ * This parser provides basic XML parsing functionality without requiring a full XML library.
+ * It's designed specifically for parsing AWS service responses and handles common XML patterns.
+ * The parser performs basic entity decoding and content extraction but does not validate
+ * XML structure or handle complex XML features like namespaces, CDATA, or processing instructions.
+ */
+object SimpleXmlParser:
+
+  /**
+   * Decodes the five predefined XML entities to their corresponding characters.
+   *
+   * This method replaces the predefined XML entities with their actual character representations:
+   * - &lt;   → <
+   * - &gt;   → >
+   * - &quot; → "
+   * - &apos; → '
+   * - &amp;  → &
+   *
+   * The ampersand entity is decoded LAST so that an escaped entity such as
+   * "&amp;lt;" correctly decodes to the literal text "&lt;" instead of being
+   * double-decoded to "<".
+   *
+   * @param s The string containing XML entities to decode
+   * @return The string with XML entities decoded to their character equivalents
+   */
+  def decodeXmlEntities(s: String): String =
+    s.replace("&lt;", "<")
+      .replace("&gt;", ">")
+      .replace("&quot;", "\"")
+      .replace("&apos;", "'")
+      .replace("&amp;", "&")
+
+  /**
+   * Extracts the content between XML tags.
+   *
+   * Finds the first occurrence of the specified XML tag and extracts its content.
+   * The content is trimmed and XML entities are decoded. This method only handles
+   * simple tags without attributes and does not support nested tags with the same name.
+ *
+ * @param tagName The name of the XML tag (without angle brackets)
+ * @param xml The XML string to search in
+ * @return Some(content) if the tag is found with content, None otherwise
+ */
+  def extractTagContent(tagName: String, xml: String): Option[String] = {
+    val startTag = s"<$tagName>"
+    // Closing tag, e.g. "</Credentials>" for tagName "Credentials"; searching for an
+    // empty string here would match at the content start and always yield "".
+    val endTag   = s"</$tagName>"
+    val startIdx = xml.indexOf(startTag)
+
+    if startIdx < 0 then None
+    else {
+      val contentStart = startIdx + startTag.length
+      val endIdx       = xml.indexOf(endTag, contentStart)
+      if endIdx < 0 then None
+      else Some(decodeXmlEntities(xml.substring(contentStart, endIdx).trim))
+    }
+  }
+
+  /**
+   * Extracts a complete XML section including the opening and closing tags.
+   *
+   * Finds the first occurrence of the specified XML tag and extracts the entire
+   * section including the tags themselves. This is useful for extracting nested
+   * XML structures that need further processing.
+   *
+   * @param tagName The name of the XML tag (without angle brackets)
+   * @param xml The XML string to search in
+   * @return Some(section) if the tag section is found, None otherwise
+   */
+  def extractSection(tagName: String, xml: String): Option[String] = {
+    val startTag = s"<$tagName>"
+    // Closing tag, reconstructed the same way as in extractTagContent.
+    val endTag   = s"</$tagName>"
+    val startIdx = xml.indexOf(startTag)
+
+    if startIdx < 0 then None
+    else {
+      val endIdx = xml.indexOf(endTag, startIdx)
+      if endIdx < 0 then None
+      else Some(xml.substring(startIdx, endIdx + endTag.length))
+    }
+  }
+
+  /**
+   * Extracts tag content and throws an exception if the tag is missing or empty.
+   *
+   * This is a strict version of extractTagContent that requires the tag to exist
+   * and have non-empty content. If the tag is missing, empty, or contains only
+   * whitespace, an IllegalArgumentException is thrown with the provided error message.
+ * + * @param tagName The name of the XML tag (without angle brackets) + * @param xml The XML string to search in + * @param errorMsg The error message to use if the tag is missing or empty + * @return The tag content if found and non-empty + * @throws IllegalArgumentException if the tag is missing, empty, or contains only whitespace + */ + def requireTag(tagName: String, xml: String, errorMsg: String): String = + extractTagContent(tagName, xml) + .filter(_.nonEmpty) + .getOrElse(throw new IllegalArgumentException(errorMsg)) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala new file mode 100644 index 000000000..913abb747 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/main/scala/ldbc/amazon/util/SystemSetting.scala @@ -0,0 +1,275 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.util + +/** + * Base trait for system setting definitions. This trait serves as a marker interface + * for all system setting implementations in the AWS authentication plugin. + * + * System settings provide a way to configure AWS SDK behavior through system properties + * and environment variables. Implementations of this trait define the available + * configuration options and their default values. + * + * @see [[SdkSystemSetting]] for the concrete implementation with AWS SDK settings + */ +trait SystemSetting + +/** + * Enumeration of AWS SDK system settings that can be configured through system properties. + * + * This enum defines all the AWS SDK configuration options available through system properties, + * providing a type-safe way to access configuration values. 
Each setting includes both the + * system property name and an optional default value. + * + * The settings cover various aspects of AWS SDK configuration including: + * - Credential configuration (access keys, session tokens, roles) + * - Regional settings and endpoint configuration + * - Metadata service settings + * - HTTP client configuration + * - Optimization and feature toggles + * - Authentication and signing preferences + * + * @param systemProperty the name of the system property used to configure this setting + * @param defaultValue the default value to use if the system property is not set, None if no default + * + * @example {{{ + * // Access the system property name + * val propertyName = SdkSystemSetting.AWS_REGION.systemProperty + * + * // Get the default value + * val default = SdkSystemSetting.AWS_EC2_METADATA_DISABLED.defaultValue + * }}} + * + * @see [[SystemSetting]] + * @see [[ldbc.amazon.client.HttpClient]] for HTTP client configuration + */ +enum SdkSystemSetting(val systemProperty: String, val defaultValue: Option[String]): + /** + * Configure the AWS access key ID. + * + * This value will not be ignored if the [[AWS_SECRET_ACCESS_KEY]] is not specified. + */ + case AWS_ACCESS_KEY_ID extends SdkSystemSetting("aws.accessKeyId", None) + + /** + * Configure the AWS secret access key. + * + * This value will not be ignored if the [[AWS_ACCESS_KEY_ID]] is not specified. + */ + case AWS_SECRET_ACCESS_KEY extends SdkSystemSetting("aws.secretAccessKey", None) + + /** + * Configure the AWS session token. + */ + case AWS_SESSION_TOKEN extends SdkSystemSetting("aws.sessionToken", None) + + /** + * Configure the AWS account id associated with credentials supplied through system properties. + */ + case AWS_ACCOUNT_ID extends SdkSystemSetting("aws.accountId", None) + + /** + * Configure the AWS web identity token file path. 
+ */ + case AWS_WEB_IDENTITY_TOKEN_FILE extends SdkSystemSetting("aws.webIdentityTokenFile", None) + + /** + * Configure the AWS role arn. + */ + case AWS_ROLE_ARN extends SdkSystemSetting("aws.roleArn", None) + + /** + * Configure the session name for a role. + */ + case AWS_ROLE_SESSION_NAME extends SdkSystemSetting("aws.roleSessionName", None) + + /** + * Configure the default region. + */ + case AWS_REGION extends SdkSystemSetting("aws.region", None) + + /** + * Whether to load information such as credentials, regions from EC2 Metadata instance service. + */ + case AWS_EC2_METADATA_DISABLED extends SdkSystemSetting("aws.disableEc2Metadata", Some("false")) + + /** + * Whether to disable fallback to insecure EC2 Metadata instance service v1 on errors or timeouts. + */ + case AWS_EC2_METADATA_V1_DISABLED extends SdkSystemSetting("aws.disableEc2MetadataV1", None) + + /** + * The EC2 instance metadata service endpoint. + * + * This allows a service running in EC2 to automatically load its credentials and region without needing to configure them + * in the SdkClientBuilder. + */ + case AWS_EC2_METADATA_SERVICE_ENDPOINT + extends SdkSystemSetting("aws.ec2MetadataServiceEndpoint", Some("http://169.254.169.254")) + + case AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE + extends SdkSystemSetting("aws.ec2MetadataServiceEndpointMode", Some("IPv4")) + + /** + * The number of seconds (either as an integer or double) before a connection to the instance + * metadata service should time out. This value is applied to both the socket connect and read timeouts. + * + * The timeout can be configured using the system property "aws.ec2MetadataServiceTimeout". If not set, + * a default timeout is used. This setting is crucial for ensuring timely responses from the instance + * metadata service in environments with varying network conditions. 
+ */ + case AWS_METADATA_SERVICE_TIMEOUT extends SdkSystemSetting("aws.ec2MetadataServiceTimeout", Some("1")) + + /** + * The elastic container metadata service endpoint that should be called by the ContainerCredentialsProvider + * when loading data from the container metadata service. + * + * This allows a service running in an elastic container to automatically load its credentials without needing to configure + * them in the SdkClientBuilder. + * + * This is not used if the [[AWS_CONTAINER_CREDENTIALS_RELATIVE_URI]] is not specified. + */ + case AWS_CONTAINER_SERVICE_ENDPOINT + extends SdkSystemSetting("aws.containerServiceEndpoint", Some("http://169.254.170.2")) + + /** + * The elastic container metadata service path that should be called by the ContainerCredentialsProvider when + * loading credentials form the container metadata service. If this is not specified, credentials will not be automatically + * loaded from the container metadata service. + * + * @see #AWS_CONTAINER_SERVICE_ENDPOINT + */ + case AWS_CONTAINER_CREDENTIALS_RELATIVE_URI extends SdkSystemSetting("aws.containerCredentialsPath", None) + + /** + * The full URI path to a localhost metadata service to be used. + */ + case AWS_CONTAINER_CREDENTIALS_FULL_URI extends SdkSystemSetting("aws.containerCredentialsFullUri", None) + + /** + * An authorization token to pass to a container metadata service, only used when [[AWS_CONTAINER_CREDENTIALS_FULL_URI]] + * is specified. + * + * @see #AWS_CONTAINER_CREDENTIALS_FULL_URI + */ + case AWS_CONTAINER_AUTHORIZATION_TOKEN extends SdkSystemSetting("aws.containerAuthorizationToken", None) + + /** + * The absolute file path containing the authorization token in plain text to pass to a container metadata + * service, only used when [[AWS_CONTAINER_CREDENTIALS_FULL_URI]] is specified. 
+ * + * @see #AWS_CONTAINER_CREDENTIALS_FULL_URI + */ + case AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE extends SdkSystemSetting("aws.containerAuthorizationTokenFile", None) + + /** + * Explicitly identify the default synchronous HTTP implementation the SDK will use. Useful + * when there are multiple implementations on the classpath or as a performance optimization + * since implementation discovery requires classpath scanning. + */ + case SYNC_HTTP_SERVICE_IMPL extends SdkSystemSetting("software.amazon.awssdk.http.service.impl", None) + + /** + * Explicitly identify the default Async HTTP implementation the SDK will use. Useful + * when there are multiple implementations on the classpath or as a performance optimization + * since implementation discovery requires classpath scanning. + */ + case ASYNC_HTTP_SERVICE_IMPL extends SdkSystemSetting("software.amazon.awssdk.http.async.service.impl", None) + + /** + * Whether CBOR optimization should automatically be used if its support is found on the classpath and the service supports + * CBOR-formatted JSON. + */ + case CBOR_ENABLED extends SdkSystemSetting("aws.cborEnabled", Some("true")) + + /** + * Whether binary ION representation optimization should automatically be used if the service supports ION. + */ + case BINARY_ION_ENABLED extends SdkSystemSetting("aws.binaryIonEnabled", Some("true")) + + /** + * The execution environment of the SDK user. This is automatically set in certain environments by the underlying AWS service. + * For example, AWS Lambda will automatically specify a runtime indicating that the SDK is being used within Lambda. + */ + case AWS_EXECUTION_ENV extends SdkSystemSetting("aws.executionEnvironment", None) + + /** + * Whether endpoint discovery should be enabled. + */ + case AWS_ENDPOINT_DISCOVERY_ENABLED extends SdkSystemSetting("aws.endpointDiscoveryEnabled", None) + + /** + * Which [[RetryMode]] to use for the default [[RetryPolicy]], when one is not specified at the client level. 
+ */ + case AWS_RETRY_MODE extends SdkSystemSetting("aws.retryMode", None) + + /** + * Which DefaultsMode to use, case insensitive + */ + case AWS_DEFAULTS_MODE extends SdkSystemSetting("aws.defaultsMode", None) + + /** + * Which AccountIdEndpointMode to use, case insensitive + */ + case AWS_ACCOUNT_ID_ENDPOINT_MODE extends SdkSystemSetting("aws.accountIdEndpointMode", None) + + /** + * Defines whether dualstack endpoints should be resolved during default endpoint resolution instead of non-dualstack + * endpoints. + */ + case AWS_USE_DUALSTACK_ENDPOINT extends SdkSystemSetting("aws.useDualstackEndpoint", None) + + /** + * Defines whether fips endpoints should be resolved during default endpoint resolution instead of non-fips endpoints. + */ + case AWS_USE_FIPS_ENDPOINT extends SdkSystemSetting("aws.useFipsEndpoint", None) + + /** + * Whether request compression is disabled for operations marked with the RequestCompression trait. The default value is + * false, i.e., request compression is enabled. + */ + case AWS_DISABLE_REQUEST_COMPRESSION extends SdkSystemSetting("aws.disableRequestCompression", None) + + /** + * Defines the minimum compression size in bytes, inclusive, for a request to be compressed. The default value is 10_240. + * The value must be non-negative and no greater than 10_485_760. + */ + case AWS_REQUEST_MIN_COMPRESSION_SIZE_BYTES extends SdkSystemSetting("aws.requestMinCompressionSizeBytes", None) + + /** + * Defines a file path from which partition metadata should be loaded. If this isn't specified, the partition + * metadata deployed with the SDK client will be used instead. + */ + case AWS_PARTITIONS_FILE extends SdkSystemSetting("aws.partitionsFile", None) + + /** + * The request checksum calculation setting. The default value is WHEN_SUPPORTED. + */ + case AWS_REQUEST_CHECKSUM_CALCULATION extends SdkSystemSetting("aws.requestChecksumCalculation", None) + + /** + * The response checksum validation setting. The default value is WHEN_SUPPORTED. 
+ */ + case AWS_RESPONSE_CHECKSUM_VALIDATION extends SdkSystemSetting("aws.responseChecksumValidation", None) + + /** + * Configure an optional identification value to be appended to the user agent header. + * The value should be less than 50 characters in length and is None by default. + */ + case AWS_SDK_UA_APP_ID extends SdkSystemSetting("sdk.ua.appId", None) + + /** + * Configure the SIGV4A signing region set. + * This is a non-empty, comma-delimited list of AWS region names used during signing. + */ + case AWS_SIGV4A_SIGNING_REGION_SET extends SdkSystemSetting("aws.sigv4a.signing.region.set", None) + + /** + * Configure the preferred auth scheme to use. + * This is a comma-delimited list of AWS auth scheme names used during signing. + */ + case AWS_AUTH_SCHEME_PREFERENCE extends SdkSystemSetting("aws.authSchemePreference", None) diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProviderTest.scala new file mode 100644 index 000000000..fc102bf32 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ContainerCredentialsProviderTest.scala @@ -0,0 +1,533 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.net.URI +import java.time.Instant + +import cats.effect.{ IO, Ref } +import cats.effect.std.Env + +import fs2.io.file.{ Files, Path } + +import munit.CatsEffectSuite + +import ldbc.amazon.auth.credentials.AwsSessionCredentials +import ldbc.amazon.client.{ HttpClient, HttpResponse } +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.AwsCredentials + +class ContainerCredentialsProviderTest extends CatsEffectSuite: + + // Test fixtures + private val validJsonResponse = """{ + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", + "Token": "session-token-123", + "Expiration": "2024-12-31T23:59:59Z", + "RoleArn": "arn:aws:iam::123456789012:role/test-role" + }""" + + private val minimalJsonResponse = """{ + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", + "Token": "session-token-123", + "Expiration": "2024-12-31T23:59:59Z" + }""" + + private val invalidJsonResponse = """{ + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE" + }""" + + private val eksEndpoint = "http://169.254.170.23/v1/credentials" + private val authToken = "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." 
+ + // Mock HTTP client + private def mockHttpClient( + responseBody: String, + statusCode: Int = 200, + captureRequest: Option[Ref[IO, Option[MockRequest]]] = None + ): HttpClient[IO] = + new HttpClient[IO]: + override def get(uri: URI, headers: Map[String, String]): IO[HttpResponse] = + for _ <- captureRequest match + case Some(ref) => ref.set(Some(MockRequest(uri, headers))) + case None => IO.unit + yield HttpResponse( + statusCode = statusCode, + headers = Map.empty, + body = responseBody + ) + + override def post(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + IO.raiseError(new UnsupportedOperationException("POST not supported in mock")) + + override def put(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + IO.raiseError(new UnsupportedOperationException("PUT not supported in mock")) + + // Mock environment + private def mockEnv(envVars: Map[String, String]): Env[IO] = + new Env[IO]: + override def get(name: String): IO[Option[String]] = + IO.pure(envVars.get(name)) + override def entries: IO[scala.collection.immutable.Iterable[(String, String)]] = + IO.pure(envVars) + + // Use real Files instance and mock file operations differently + + case class MockRequest(uri: URI, headers: Map[String, String]) + + test("resolveCredentials with ECS relative URI") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> authToken + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + // Verify credentials + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, "ASIAIOSFODNN7EXAMPLE") + 
assertEquals(session.secretAccessKey, "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") + assertEquals(session.sessionToken, "session-token-123") + assertEquals(session.validateCredentials, false) + assertEquals(session.providerName, Some("1")) + assertEquals(session.accountId, Some("123456789012")) + assertEquals(session.expirationTime, Some(Instant.parse("2024-12-31T23:59:59Z"))) + case _ => fail("Expected AwsSessionCredentials") + + // Verify HTTP request + captured match + case Some(request) => + assertEquals(request.uri.toString, "http://169.254.170.2/v2/credentials/test-uuid") + assertEquals(request.headers("Authorization"), authToken) + assertEquals(request.headers("Accept"), "application/json") + assertEquals(request.headers("User-Agent"), "aws-sdk-scala/ldbc") + case None => fail("Expected captured request") + } + + test("resolveCredentials with EKS full URI") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> eksEndpoint, + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> authToken + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(minimalJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + // Verify credentials without RoleArn + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, "ASIAIOSFODNN7EXAMPLE") + assertEquals(session.secretAccessKey, "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") + assertEquals(session.sessionToken, "session-token-123") + assertEquals(session.accountId, None) // No RoleArn means no account ID + case _ => fail("Expected AwsSessionCredentials") + + // Verify HTTP request + captured match + case Some(request) => + assertEquals(request.uri.toString, eksEndpoint) + assertEquals(request.headers("Authorization"), authToken) + case None => 
fail("Expected captured request") + } + + test("resolveCredentials with token from direct environment variable") { + // This test works in both JVM and JavaScript environments + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> s" $authToken \n" // With whitespace + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify token was trimmed and used (ContainerCredentialsProvider trims the token) + captured match + case Some(request) => + assertEquals(request.headers("Authorization"), authToken.trim) + case None => fail("Expected captured request") + } + + test("resolveCredentials without authorization token") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + // No authorization token + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify no Authorization header + captured match + case Some(request) => + assert(!request.headers.contains("Authorization")) + case None => fail("Expected captured request") + } + + test("fail when no container credentials environment variables 
are set") { + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(validJsonResponse) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load container credentials")) + assert(exception.getMessage.contains("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")) + assert(exception.getMessage.contains("AWS_CONTAINER_CREDENTIALS_FULL_URI")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when HTTP request fails") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + ) + + given Env[IO] = mockEnv(envVars) + + val httpClient = mockHttpClient("Internal Server Error", statusCode = 500) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Container credentials request failed")) + assert(exception.getMessage.contains("500")) + assert(exception.getMessage.contains("Internal Server Error")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when JSON response is invalid") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + ) + + given Env[IO] = mockEnv(envVars) + + val httpClient = mockHttpClient(invalidJsonResponse) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Failed to parse container credentials response")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when token file does not exist") { + val tokenFilePath = Path("/tmp/non-existent-token") + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + 
"AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" -> tokenFilePath.toString + ) + + given Env[IO] = mockEnv(envVars) + + val httpClient = mockHttpClient(validJsonResponse) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for credentials <- provider.resolveCredentials() + yield + // Should succeed without token (Authorization header optional) + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + } + + test("handle empty authorization token") { + // Test with empty environment variable + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> " \n " // Only whitespace + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify no Authorization header for empty token + captured match + case Some(request) => + assert(!request.headers.contains("Authorization")) + case None => fail("Expected captured request") + } + + test("prefer direct token over token file path") { + // Test priority without actual file operations + val directToken = "direct-token" + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> directToken, + "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" -> "/some/file/path" // File path provided but direct token should take precedence + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, 
captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify direct token was used, not file token + captured match + case Some(request) => + assertEquals(request.headers("Authorization"), directToken) + case None => fail("Expected captured request") + } + + test("prefer relative URI over full URI") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> eksEndpoint + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify relative URI was used (ECS endpoint) + captured match + case Some(request) => + assertEquals(request.uri.toString, "http://169.254.170.2/v2/credentials/test-uuid") + case None => fail("Expected captured request") + } + + test("extractAccountIdFromRoleArn extracts account ID correctly") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + ) + + given Env[IO] = mockEnv(envVars) + + val responseWithValidArn = """{ + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", + "Token": "session-token-123", + "Expiration": "2024-12-31T23:59:59Z", + "RoleArn": "arn:aws:iam::987654321098:role/test-role" + }""" + + val httpClient = mockHttpClient(responseWithValidArn) + val provider 
= ContainerCredentialsProvider.create[IO](httpClient) + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accountId, Some("987654321098")) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("extractAccountIdFromRoleArn handles invalid ARN") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + ) + + given Env[IO] = mockEnv(envVars) + + val responseWithInvalidArn = """{ + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", + "Token": "session-token-123", + "Expiration": "2024-12-31T23:59:59Z", + "RoleArn": "invalid-arn-format" + }""" + + val httpClient = mockHttpClient(responseWithInvalidArn) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accountId, None) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("isAvailable returns true when relative URI is set") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + ) + + given Env[IO] = mockEnv(envVars) + + ContainerCredentialsProvider.isAvailable[IO]().map { available => + assertEquals(available, true) + } + } + + test("isAvailable returns true when full URI is set") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> eksEndpoint + ) + + given Env[IO] = mockEnv(envVars) + + ContainerCredentialsProvider.isAvailable[IO]().map { available => + assertEquals(available, true) + } + } + + test("isAvailable returns false when no URIs are set") { + given Env[IO] = mockEnv(Map.empty) + + ContainerCredentialsProvider.isAvailable[IO]().map { available => + assertEquals(available, false) + } + } + + test("isAvailable returns false when URIs are empty") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> " ", + "AWS_CONTAINER_CREDENTIALS_FULL_URI" -> "" + 
) + + given Env[IO] = mockEnv(envVars) + + ContainerCredentialsProvider.isAvailable[IO]().map { available => + assertEquals(available, false) + } + } + + test("implements AwsCredentialsProvider trait") { + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid" + ) + + given Env[IO] = mockEnv(envVars) + + val httpClient = mockHttpClient(validJsonResponse) + val provider: ldbc.amazon.identity.AwsCredentialsProvider[IO] = + ContainerCredentialsProvider.create[IO](httpClient) + + provider.resolveCredentials().map { credentials => + assert(credentials.isInstanceOf[AwsCredentials]) + } + } + + // JavaScript-compatible tests for file-based functionality + test("resolveCredentials with token file path (JavaScript)") { + // In JavaScript environment, test with direct token instead of file + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> authToken + ) + + given Env[IO] = mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify token was used + captured match + case Some(request) => + assertEquals(request.headers("Authorization"), authToken) + case None => fail("Expected captured request") + } + + test("token preference logic (JavaScript)") { + // Test that direct token takes precedence over file token path + val envVars = Map( + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" -> "/v2/credentials/test-uuid", + "AWS_CONTAINER_AUTHORIZATION_TOKEN" -> "direct-token", + "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" -> "/non/existent/file" + ) + + given Env[IO] 
= mockEnv(envVars) + + val requestCapture = Ref.unsafe[IO, Option[MockRequest]](None) + val httpClient = mockHttpClient(validJsonResponse, captureRequest = Some(requestCapture)) + val provider = ContainerCredentialsProvider.create[IO](httpClient) + + for + credentials <- provider.resolveCredentials() + captured <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => // Success + case _ => fail("Expected AwsSessionCredentials") + + // Verify direct token was used + captured match + case Some(request) => + assertEquals(request.headers("Authorization"), "direct-token") + case None => fail("Expected captured request") + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProviderTest.scala new file mode 100644 index 000000000..5e574bf2b --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/EnvironmentVariableCredentialsProviderTest.scala @@ -0,0 +1,378 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import cats.effect.std.Env +import cats.effect.IO + +import munit.CatsEffectSuite + +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.AwsCredentials +import ldbc.amazon.util.SdkSystemSetting + +class EnvironmentVariableCredentialsProviderTest extends CatsEffectSuite: + + // Test fixtures + private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + + // Mock environment + private def mockEnv(envVars: Map[String, String]): Env[IO] = + new Env[IO]: + override def get(name: String): IO[Option[String]] = + IO.pure(envVars.get(name)) + override def entries: IO[scala.collection.immutable.Iterable[(String, String)]] = + IO.pure(envVars) + + test("resolveCredentials with basic credentials") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, testAccessKeyId) + assertEquals(basic.secretAccessKey, testSecretAccessKey) + assertEquals(basic.validateCredentials, false) + assertEquals(basic.providerName, Some("g")) + assertEquals(basic.accountId, None) + assertEquals(basic.expirationTime, None) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("resolveCredentials with session credentials") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey, + "AWS_SESSION_TOKEN" -> testSessionToken + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + 
provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + assertEquals(session.validateCredentials, false) + assertEquals(session.providerName, Some("g")) + assertEquals(session.accountId, None) + assertEquals(session.expirationTime, None) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("fail when AWS_ACCESS_KEY_ID is missing") { + val envVars = Map( + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Access key")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when AWS_SECRET_ACCESS_KEY is missing") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Secret key")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when no environment variables are set") { + given Env[IO] = mockEnv(Map.empty) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Access key")) + case _ => fail("Expected SdkClientException") + } + } + + 
test("handle empty environment variable values") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> "", + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, "") // Empty string is accepted + assertEquals(basic.secretAccessKey, testSecretAccessKey) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("handle whitespace-only environment variable values") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> " ", + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, "") // Trimmed to empty string + assertEquals(basic.secretAccessKey, testSecretAccessKey) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("handle empty session token") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey, + "AWS_SESSION_TOKEN" -> "" + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, "") // Empty string, not None + case _ => fail("Expected AwsSessionCredentials even with empty session token") + } + } + + test("handle whitespace-only session token") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey, + "AWS_SESSION_TOKEN" -> " \n " + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + 
provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, "") // Trimmed to empty string + case _ => fail("Expected AwsSessionCredentials even with whitespace session token") + } + } + + test("trim whitespace from credentials") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> s" $testAccessKeyId ", + "AWS_SECRET_ACCESS_KEY" -> s"\n$testSecretAccessKey\t", + "AWS_SESSION_TOKEN" -> s" $testSessionToken \n" + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("provider name is correct") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { credentials => + assertEquals(credentials.providerName, Some("g")) // BusinessMetricFeatureId.CREDENTIALS_ENV_VARS.code + } + } + + test("loadSetting method reads from environment") { + val testValue = "test-setting-value" + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testValue + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map { result => + assertEquals(result, Some(testValue)) + } + } + + test("loadSetting returns None for missing environment variable") { + given Env[IO] = mockEnv(Map.empty) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + 
provider.loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map { result => + assertEquals(result, None) + } + } + + test("implements AwsCredentialsProvider trait") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider: ldbc.amazon.identity.AwsCredentialsProvider[IO] = + new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { credentials => + assert(credentials.isInstanceOf[AwsCredentials]) + } + } + + test("consistent behavior across multiple calls") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey, + "AWS_SESSION_TOKEN" -> testSessionToken + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + for + credentials1 <- provider.resolveCredentials() + credentials2 <- provider.resolveCredentials() + yield + assertEquals(credentials1.accessKeyId, credentials2.accessKeyId) + assertEquals(credentials1.secretAccessKey, credentials2.secretAccessKey) + credentials1 match + case session1: AwsSessionCredentials => + credentials2 match + case session2: AwsSessionCredentials => + assertEquals(session1.sessionToken, session2.sessionToken) + case _ => fail("Both credentials should be AwsSessionCredentials") + case _ => fail("Expected AwsSessionCredentials") + } + + test("validates credentials format") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { credentials => + // Verify access key format (starts with AKIA for IAM users) + assert(credentials.accessKeyId.startsWith("AKIA")) + assert(credentials.accessKeyId.length >= 16) + + // Verify secret key is not empty and has reasonable length + 
assert(credentials.secretAccessKey.nonEmpty) + assert(credentials.secretAccessKey.length >= 20) + } + } + + test("case sensitivity of environment variable names") { + val envVars = Map( + "aws_access_key_id" -> testAccessKeyId, // lowercase + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + // Should fail because AWS expects exact case + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load")) + case Left(other) => fail(s"Expected SdkClientException, got ${ other.getClass.getSimpleName }") + case Right(_) => fail("Should fail with incorrect case environment variable") + } + } + + test("alternative environment variable names are not supported") { + val envVars = Map( + "AWS_ACCESS_KEY" -> testAccessKeyId, // Not AWS_ACCESS_KEY_ID + "AWS_SECRET_KEY" -> testSecretAccessKey // Not AWS_SECRET_ACCESS_KEY + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load")) + case Left(other) => fail(s"Expected SdkClientException, got ${ other.getClass.getSimpleName }") + case Right(_) => fail("Should fail with non-standard environment variable names") + } + } + + test("credentials validation is disabled by default") { + val envVars = Map( + "AWS_ACCESS_KEY_ID" -> testAccessKeyId, + "AWS_SECRET_ACCESS_KEY" -> testSecretAccessKey + ) + + given Env[IO] = mockEnv(envVars) + + val provider = new EnvironmentVariableCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.validateCredentials, false) + case session: AwsSessionCredentials => + assertEquals(session.validateCredentials, false) + } + } diff --git 
a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProviderTest.scala new file mode 100644 index 000000000..e57db2970 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/InstanceProfileCredentialsProviderTest.scala @@ -0,0 +1,571 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.net.URI +import java.time.Instant + +import cats.effect.{ IO, Ref } +import cats.effect.std.Env + +import munit.CatsEffectSuite + +import ldbc.amazon.client.{ HttpClient, HttpResponse } +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.AwsCredentials + +class InstanceProfileCredentialsProviderTest extends CatsEffectSuite: + + // Test fixtures + private val testRoleName = "test-role" + private val testAccessKeyId = "ASIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + private val testMetadataToken = "AQAEAFUsWMxSzWnKOL7wJKWJHFGDL+gUSCMd6FKPYa8wOYwR=" + private val futureExpiration = Instant.now().plusSeconds(3600) + + // Sample IMDS JSON response + private val validCredentialsResponse = s"""{ + "Code": "Success", + "LastUpdated": "2024-01-01T12:00:00Z", + "Type": "AWS-HMAC", + "AccessKeyId": "$testAccessKeyId", + "SecretAccessKey": "$testSecretAccessKey", + "Token": "$testSessionToken", + "Expiration": "${ futureExpiration.toString }" + }""" + + private val failedCredentialsResponse = s"""{ + "Code": "Failed", + "LastUpdated": "2024-01-01T12:00:00Z", + "Type": 
"AWS-HMAC", + "AccessKeyId": "", + "SecretAccessKey": "", + "Token": "", + "Expiration": "" + }""" + + // Mock HTTP client + case class MockRequest(uri: URI, headers: Map[String, String], method: String, body: Option[String]) + + private def mockHttpClient( + responses: Map[String, HttpResponse], + captureRequest: Option[Ref[IO, List[MockRequest]]] = None + ): HttpClient[IO] = + new HttpClient[IO]: + override def get(uri: URI, headers: Map[String, String]): IO[HttpResponse] = + for _ <- captureRequest match + case Some(ref) => ref.update(_ :+ MockRequest(uri, headers, "GET", None)) + case None => IO.unit + yield responses.getOrElse( + uri.toString, + HttpResponse(404, Map.empty, "Not Found") + ) + + override def post(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + for _ <- captureRequest match + case Some(ref) => ref.update(_ :+ MockRequest(uri, headers, "POST", Some(body))) + case None => IO.unit + yield responses.getOrElse( + uri.toString, + HttpResponse(404, Map.empty, "Not Found") + ) + + override def put(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + for _ <- captureRequest match + case Some(ref) => ref.update(_ :+ MockRequest(uri, headers, "PUT", Some(body))) + case None => IO.unit + yield responses.getOrElse( + uri.toString, + HttpResponse(200, Map.empty, testMetadataToken) + ) + + // Mock environment + private def mockEnv(envVars: Map[String, String]): Env[IO] = + new Env[IO]: + override def get(name: String): IO[Option[String]] = + IO.pure(envVars.get(name)) + override def entries: IO[scala.collection.immutable.Iterable[(String, String)]] = + IO.pure(envVars) + + test("resolveCredentials with successful IMDSv2 flow") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + 
s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) + ) + + given Env[IO] = mockEnv(Map.empty) + + val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) + val httpClient = mockHttpClient(responses, Some(requestCapture)) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + assertEquals(session.validateCredentials, false) + assertEquals(session.providerName, Some("0")) // BusinessMetricFeatureId.CREDENTIALS_IMDS.code + assertEquals(session.accountId, None) + assertEquals(session.expirationTime, Some(futureExpiration)) + case _ => fail("Expected AwsSessionCredentials") + + // Verify the request flow + val tokenRequest = requests.find(_.method == "PUT") + assert(tokenRequest.isDefined) + assert(tokenRequest.get.headers.contains("X-aws-ec2-metadata-token-ttl-seconds")) + + val roleRequest = requests.find(_.uri.toString.endsWith("security-credentials/")) + assert(roleRequest.isDefined) + assert(roleRequest.get.headers.contains("X-aws-ec2-metadata-token")) + + val credRequest = requests.find(_.uri.toString.endsWith(s"security-credentials/$testRoleName")) + assert(credRequest.isDefined) + assert(credRequest.get.headers.contains("X-aws-ec2-metadata-token")) + } + + test("resolveCredentials with IMDSv1 fallback when token acquisition fails") { + val responses = Map( + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse 
+ ) + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = new HttpClient[IO]: + override def get(uri: URI, headers: Map[String, String]): IO[HttpResponse] = + IO.pure(responses.getOrElse(uri.toString, HttpResponse(404, Map.empty, "Not Found"))) + override def post(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + IO.raiseError(new UnsupportedOperationException("POST not supported")) + override def put(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + // Simulate token acquisition failure + IO.raiseError(new Exception("Token acquisition failed")) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials <- provider.resolveCredentials() + yield credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + case _ => fail("Expected AwsSessionCredentials") + } + + test("fail when EC2 metadata is disabled") { + given Env[IO] = mockEnv(Map("AWS_EC2_METADATA_DISABLED" -> "true")) + + val httpClient = mockHttpClient(Map.empty) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("EC2 metadata service is disabled")) + case _ => fail("Expected SdkClientException") + } + + test("fail when no IAM roles are available") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse( + 200, + Map.empty, + "" + ) // Empty response + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + 
result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("No IAM roles found")) + case _ => fail("Expected SdkClientException") + } + + test("fail when instance profile is not attached (403)") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(403, Map.empty, "Forbidden") + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Forbidden (403)")) + assert(exception.getMessage.contains("No instance profile attached")) + case _ => fail("Expected SdkClientException") + } + + test("fail when metadata service is not available (404)") { + val responses = Map( + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(404, Map.empty, "Not Found") + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Not Found (404)")) + assert(exception.getMessage.contains("Instance metadata not available")) + case _ => fail("Expected SdkClientException") + } + + test("fail when metadata token is invalid (401)") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse( + 401, + Map.empty, + "Unauthorized" + ) + ) + + given Env[IO] = 
mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unauthorized (401)")) + assert(exception.getMessage.contains("Invalid or expired metadata token")) + case _ => fail("Expected SdkClientException") + } + + test("fail when credentials response has failed status") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + failedCredentialsResponse + ) + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + result <- provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Failed to retrieve credentials")) + assert(exception.getMessage.contains("Failed")) + case _ => fail("Expected SdkClientException") + } + + test("fail when credentials response is malformed JSON") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + "invalid json" + ) + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + result <- 
provider.resolveCredentials().attempt + yield result match + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Failed to parse")) + case _ => fail("Expected SdkClientException") + } + + test("use custom IMDS endpoint when environment variable is set") { + val customEndpoint = "http://169.254.169.254:8080" + val responses = Map( + s"$customEndpoint/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + s"$customEndpoint/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"$customEndpoint/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) + ) + + given Env[IO] = mockEnv(Map("AWS_EC2_METADATA_SERVICE_ENDPOINT" -> customEndpoint)) + + val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) + val httpClient = mockHttpClient(responses, Some(requestCapture)) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") + + // Verify custom endpoint was used + assert(requests.exists(_.uri.toString.startsWith(customEndpoint))) + } + + test("strip trailing slash from custom endpoint") { + val customEndpoint = "http://169.254.169.254:8080/" + val expectedEndpoint = "http://169.254.169.254:8080" + val responses = Map( + s"$expectedEndpoint/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + s"$expectedEndpoint/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"$expectedEndpoint/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) + ) + + given Env[IO] = mockEnv(Map("AWS_EC2_METADATA_SERVICE_ENDPOINT" -> 
customEndpoint)) + + val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) + val httpClient = mockHttpClient(responses, Some(requestCapture)) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") + + // Verify trailing slash was stripped + assert(requests.exists(_.uri.toString.startsWith(expectedEndpoint))) + assert(!requests.exists(_.uri.toString.contains("//latest"))) + } + + test("caching works correctly - multiple calls return cached credentials") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) + ) + + given Env[IO] = mockEnv(Map.empty) + + val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) + val httpClient = mockHttpClient(responses, Some(requestCapture)) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials1 <- provider.resolveCredentials() + credentials2 <- provider.resolveCredentials() + requests <- requestCapture.get + yield + // Both calls should return the same credentials + assertEquals(credentials1.accessKeyId, credentials2.accessKeyId) + assertEquals(credentials1.secretAccessKey, credentials2.secretAccessKey) + + // Only one set of requests should have been made (first call) + val tokenRequests = requests.filter(_.method == "PUT") + assertEquals(tokenRequests.length, 1) + } + + test("credentials are refreshed when close to expiration") { + val shortExpirationTime = 
Instant.now().plusSeconds(180) // 3 minutes from now (less than 4 minute buffer) + val shortExpirationResponse = s"""{ + "Code": "Success", + "LastUpdated": "2024-01-01T12:00:00Z", + "Type": "AWS-HMAC", + "AccessKeyId": "$testAccessKeyId", + "SecretAccessKey": "$testSecretAccessKey", + "Token": "$testSessionToken", + "Expiration": "${ shortExpirationTime.toString }" + }""" + + val longExpirationTime = Instant.now().plusSeconds(3600) // 1 hour from now + val newAccessKeyId = "ASIANEWACCESSKEY123" + val refreshedResponse = s"""{ + "Code": "Success", + "LastUpdated": "2024-01-01T12:30:00Z", + "Type": "AWS-HMAC", + "AccessKeyId": "$newAccessKeyId", + "SecretAccessKey": "$testSecretAccessKey", + "Token": "$testSessionToken", + "Expiration": "${ longExpirationTime.toString }" + }""" + + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName) + ) + + // First response with short expiration, second with long expiration + val callCount = Ref.unsafe[IO, Int](0) + val httpClient = new HttpClient[IO]: + override def get(uri: URI, headers: Map[String, String]): IO[HttpResponse] = + if uri.toString.endsWith(s"security-credentials/$testRoleName") then + callCount.updateAndGet(_ + 1).map { count => + if count == 1 then HttpResponse(200, Map.empty, shortExpirationResponse) + else HttpResponse(200, Map.empty, refreshedResponse) + } + else IO.pure(responses.getOrElse(uri.toString, HttpResponse(404, Map.empty, "Not Found"))) + + override def post(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + IO.raiseError(new UnsupportedOperationException("POST not supported")) + + override def put(uri: URI, headers: Map[String, String], body: String): IO[HttpResponse] = + IO.pure(HttpResponse(200, Map.empty, testMetadataToken)) + + given Env[IO] = mockEnv(Map.empty) + + for + provider <- 
InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials1 <- provider.resolveCredentials() // Gets short expiration credentials + credentials2 <- provider.resolveCredentials() // Should refresh due to short expiration + count <- callCount.get + yield + // First credentials should have short expiration + assertEquals(credentials1.expirationTime, Some(shortExpirationTime)) + + // Second credentials should be refreshed with new access key + assertEquals(credentials2.accessKeyId, newAccessKeyId) + assertEquals(credentials2.expirationTime, Some(longExpirationTime)) + + // Should have made 2 credential requests due to refresh + assertEquals(count, 2) + } + + test("implements AwsCredentialsProvider trait") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) + ) + + given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + providerTrait: ldbc.amazon.identity.AwsCredentialsProvider[IO] = provider + credentials <- providerTrait.resolveCredentials() + yield assert(credentials.isInstanceOf[AwsCredentials]) + } + + test("handles multiple IAM roles by selecting the first one") { + val multipleRoles = s"$testRoleName\nrole2\nrole3" + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse( + 200, + Map.empty, + multipleRoles + ), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) + ) + 
+ given Env[IO] = mockEnv(Map.empty) + + val httpClient = mockHttpClient(responses) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials <- provider.resolveCredentials() + yield credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") + } + + test("request headers are set correctly") { + val responses = Map( + "http://169.254.169.254/latest/api/token" -> HttpResponse(200, Map.empty, testMetadataToken), + "http://169.254.169.254/latest/meta-data/iam/security-credentials/" -> HttpResponse(200, Map.empty, testRoleName), + s"http://169.254.169.254/latest/meta-data/iam/security-credentials/$testRoleName" -> HttpResponse( + 200, + Map.empty, + validCredentialsResponse + ) + ) + + given Env[IO] = mockEnv(Map.empty) + + val requestCapture = Ref.unsafe[IO, List[MockRequest]](List.empty) + val httpClient = mockHttpClient(responses, Some(requestCapture)) + + for + provider <- InstanceProfileCredentialsProvider.create[IO](httpClient) + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") + + // Verify PUT request headers for token acquisition + val tokenRequest = requests.find(_.method == "PUT") + assert(tokenRequest.isDefined) + assertEquals(tokenRequest.get.headers("X-aws-ec2-metadata-token-ttl-seconds"), "21600") + + // Verify GET request headers include token and user agent + val getRequests = requests.filter(_.method == "GET") + getRequests.foreach { request => + assertEquals(request.headers("Accept"), "application/json") + assertEquals(request.headers("User-Agent"), "aws-sdk-scala/ldbc") + assertEquals(request.headers("X-aws-ec2-metadata-token"), testMetadataToken) + } + } diff --git 
a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala new file mode 100644 index 000000000..9346e8cf5 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/ProfileCredentialsProviderTest.scala @@ -0,0 +1,85 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.time.Instant + +import cats.effect.std.SystemProperties +import cats.effect.IO + +import fs2.io.file.Files + +import munit.CatsEffectSuite + +import ProfileCredentialsProvider.* + +class ProfileCredentialsProviderTest extends CatsEffectSuite: + + // Test fixtures + private val testHomeDir = "/home/test" + + // Mock system properties + private def mockSystemProperties(homeDir: Option[String] = Some(testHomeDir)): SystemProperties[IO] = + new SystemProperties[IO]: + override def get(name: String): IO[Option[String]] = + name match + case "user.home" => IO.pure(homeDir) + case _ => IO.pure(None) + override def clear(key: String): IO[Option[String]] = + IO.raiseError(new UnsupportedOperationException("clear not supported in mock")) + override def set(key: String, value: String): IO[Option[String]] = + IO.raiseError(new UnsupportedOperationException("set not supported in mock")) + + test("ProfileCredentialsProvider fails when user.home is missing") { + given SystemProperties[IO] = mockSystemProperties(homeDir = None) + + assertIOBoolean( + for + provider <- ProfileCredentialsProvider.default[IO]() + result <- provider.resolveCredentials().attempt + yield result.isLeft + ) + } + + test("ProfileCredentialsProvider companion object methods work") { + val profile = 
ProfileCredentialsProvider.Profile("test", Map("key" -> "value")) + assertEquals(profile.name, "test") + assertEquals(profile.properties, Map("key" -> "value")) + } + + test("ProfileFile case class creation works") { + + val profiles = Map("default" -> Profile("default", Map("aws_access_key_id" -> "test"))) + val instant = Instant.now() + val profileFile = ProfileFile(profiles, instant) + + assertEquals(profileFile.profiles, profiles) + assertEquals(profileFile.lastModified, instant) + } + + test("Profile case class creation works") { + + val properties = Map( + "aws_access_key_id" -> "AKIAIOSFODNN7EXAMPLE", + "aws_secret_access_key" -> "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + ) + val profile = Profile("production", properties) + + assertEquals(profile.name, "production") + assertEquals(profile.properties, properties) + } + + test("ProfileCredentialsProvider handles various profile names correctly") { + given SystemProperties[IO] = mockSystemProperties() + + val profileNames = List("default", "dev", "staging", "production", "test-profile", "profile_with_underscores") + + assertIO( + IO.traverse(profileNames)(ProfileCredentialsProvider.default[IO](_)).map(_.length), + profileNames.length + ) + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProviderTest.scala new file mode 100644 index 000000000..6e4de9e81 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/SystemPropertyCredentialsProviderTest.scala @@ -0,0 +1,474 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import cats.effect.std.SystemProperties +import cats.effect.IO + +import munit.CatsEffectSuite + +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.AwsCredentials +import ldbc.amazon.util.SdkSystemSetting + +class SystemPropertyCredentialsProviderTest extends CatsEffectSuite: + + // Test fixtures + private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + + // Mock system properties + private def mockSystemProperties(sysProps: Map[String, String]): SystemProperties[IO] = + new SystemProperties[IO]: + override def get(name: String): IO[Option[String]] = + IO.pure(sysProps.get(name)) + override def clear(key: String): IO[Option[String]] = + IO.raiseError(new UnsupportedOperationException("clear not supported in mock")) + override def set(key: String, value: String): IO[Option[String]] = + IO.raiseError(new UnsupportedOperationException("set not supported in mock")) + + test("resolveCredentials with basic credentials") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, testAccessKeyId) + assertEquals(basic.secretAccessKey, testSecretAccessKey) + assertEquals(basic.validateCredentials, false) + assertEquals(basic.providerName, Some("f")) + assertEquals(basic.accountId, None) + assertEquals(basic.expirationTime, None) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("resolveCredentials with session credentials") { + val sysProps = Map( 
+ "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.sessionToken" -> testSessionToken + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + assertEquals(session.validateCredentials, false) + assertEquals(session.providerName, Some("f")) + assertEquals(session.accountId, None) + assertEquals(session.expirationTime, None) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("resolveCredentials with account ID") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.accountId" -> "123456789012" + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, testAccessKeyId) + assertEquals(basic.secretAccessKey, testSecretAccessKey) + assertEquals(basic.accountId, Some("123456789012")) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("fail when aws.accessKeyId is missing") { + val sysProps = Map( + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Access key")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when aws.secretAccessKey is missing") { + val sysProps = Map( + 
"aws.accessKeyId" -> testAccessKeyId + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Secret key")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when no system properties are set") { + given SystemProperties[IO] = mockSystemProperties(Map.empty) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Access key")) + case _ => fail("Expected SdkClientException") + } + } + + test("handle empty system property values") { + val sysProps = Map( + "aws.accessKeyId" -> "", + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, "") // Empty string is accepted + assertEquals(basic.secretAccessKey, testSecretAccessKey) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("handle whitespace-only system property values") { + val sysProps = Map( + "aws.accessKeyId" -> " ", + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, "") // Trimmed to empty string + assertEquals(basic.secretAccessKey, testSecretAccessKey) + case _ => fail("Expected 
AwsBasicCredentials") + } + } + + test("handle empty session token") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.sessionToken" -> "" + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, "") // Empty string, not None + case _ => fail("Expected AwsSessionCredentials even with empty session token") + } + } + + test("handle whitespace-only session token") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.sessionToken" -> " \n " + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, "") // Trimmed to empty string + case _ => fail("Expected AwsSessionCredentials even with whitespace session token") + } + } + + test("trim whitespace from credentials") { + val sysProps = Map( + "aws.accessKeyId" -> s" $testAccessKeyId ", + "aws.secretAccessKey" -> s"\n$testSecretAccessKey\t", + "aws.sessionToken" -> s" $testSessionToken \n" + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + case _ => fail("Expected 
AwsSessionCredentials") + } + } + + test("provider name is correct") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { credentials => + assertEquals( + credentials.providerName, + Some("f") + ) // BusinessMetricFeatureId.CREDENTIALS_JVM_SYSTEM_PROPERTIES.code + } + } + + test("loadSetting method reads from system properties") { + val testValue = "test-setting-value" + val sysProps = Map( + "aws.accessKeyId" -> testValue + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map { result => + assertEquals(result, Some(testValue)) + } + } + + test("loadSetting returns None for missing system property") { + given SystemProperties[IO] = mockSystemProperties(Map.empty) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID).map { result => + assertEquals(result, None) + } + } + + test("implements AwsCredentialsProvider trait") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider: ldbc.amazon.identity.AwsCredentialsProvider[IO] = + new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { credentials => + assert(credentials.isInstanceOf[AwsCredentials]) + } + } + + test("consistent behavior across multiple calls") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.sessionToken" -> testSessionToken + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new 
SystemPropertyCredentialsProvider[IO] + + for + credentials1 <- provider.resolveCredentials() + credentials2 <- provider.resolveCredentials() + yield + assertEquals(credentials1.accessKeyId, credentials2.accessKeyId) + assertEquals(credentials1.secretAccessKey, credentials2.secretAccessKey) + credentials1 match + case session1: AwsSessionCredentials => + credentials2 match + case session2: AwsSessionCredentials => + assertEquals(session1.sessionToken, session2.sessionToken) + case _ => fail("Both credentials should be AwsSessionCredentials") + case _ => fail("Expected AwsSessionCredentials") + } + + test("validates credentials format") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { credentials => + // Verify access key format (starts with AKIA for IAM users) + assert(credentials.accessKeyId.startsWith("AKIA")) + assert(credentials.accessKeyId.length >= 16) + + // Verify secret key is not empty and has reasonable length + assert(credentials.secretAccessKey.nonEmpty) + assert(credentials.secretAccessKey.length >= 20) + } + } + + test("system property keys are correct") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.sessionToken" -> testSessionToken, + "aws.accountId" -> "123456789012" + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + for + accessKey <- provider.loadSetting(SdkSystemSetting.AWS_ACCESS_KEY_ID) + secretKey <- provider.loadSetting(SdkSystemSetting.AWS_SECRET_ACCESS_KEY) + sessionToken <- provider.loadSetting(SdkSystemSetting.AWS_SESSION_TOKEN) + accountId <- provider.loadSetting(SdkSystemSetting.AWS_ACCOUNT_ID) + yield + assertEquals(accessKey, Some(testAccessKeyId)) + 
assertEquals(secretKey, Some(testSecretAccessKey)) + assertEquals(sessionToken, Some(testSessionToken)) + assertEquals(accountId, Some("123456789012")) + } + + test("different from environment variables - uses system properties") { + // This test ensures that SystemPropertyCredentialsProvider only reads from system properties, + // not environment variables, even if the system property names are similar + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, // System property format + "aws.secretAccessKey" -> testSecretAccessKey + // Note: NO AWS_ACCESS_KEY_ID (environment variable format) + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.accessKeyId, testAccessKeyId) + assertEquals(basic.secretAccessKey, testSecretAccessKey) + case _ => fail("Expected AwsBasicCredentials") + } + } + + test("case sensitivity of system property names") { + // System properties are case-sensitive in Java + val sysProps = Map( + "AWS.ACCESSKEYID" -> testAccessKeyId, // Wrong case + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + assert(exception.getMessage.contains("Access key")) + case Left(other) => fail(s"Expected SdkClientException, got ${ other.getClass.getSimpleName }") + case Right(_) => fail("Should fail with incorrect case system property") + } + } + + test("alternative system property names are not supported") { + val sysProps = Map( + "aws.access.key.id" -> testAccessKeyId, // Not aws.accessKeyId + "aws.secret.access.key" -> testSecretAccessKey // Not aws.secretAccessKey + ) + + given 
SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load credentials from system settings")) + case Left(other) => fail(s"Expected SdkClientException, got ${ other.getClass.getSimpleName }") + case Right(_) => fail("Should fail with non-standard system property names") + } + } + + test("credentials validation is disabled by default") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case basic: AwsBasicCredentials => + assertEquals(basic.validateCredentials, false) + case session: AwsSessionCredentials => + assertEquals(session.validateCredentials, false) + } + } + + test("handles all supported AWS system properties") { + val sysProps = Map( + "aws.accessKeyId" -> testAccessKeyId, + "aws.secretAccessKey" -> testSecretAccessKey, + "aws.sessionToken" -> testSessionToken, + "aws.accountId" -> "123456789012" + ) + + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val provider = new SystemPropertyCredentialsProvider[IO] + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + assertEquals(session.accountId, Some("123456789012")) + assertEquals(session.validateCredentials, false) + assertEquals(session.providerName, Some("f")) + assert(session.expirationTime.isEmpty) + case _ => fail("Expected AwsSessionCredentials") + } + } diff --git 
a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProviderTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProviderTest.scala new file mode 100644 index 000000000..31955ae26 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/WebIdentityTokenFileCredentialsProviderTest.scala @@ -0,0 +1,372 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials + +import java.time.Instant + +import cats.effect.{ IO, Ref } +import cats.effect.std.{ Env, SystemProperties } + +import munit.CatsEffectSuite + +import ldbc.amazon.auth.credentials.internal.WebIdentityCredentialsUtils +import ldbc.amazon.auth.credentials.AwsSessionCredentials +import ldbc.amazon.exception.SdkClientException +import ldbc.amazon.identity.AwsCredentials + +class WebIdentityTokenFileCredentialsProviderTest extends CatsEffectSuite: + + // Test fixtures + private val testRoleArn = "arn:aws:iam::123456789012:role/test-role" + private val testTokenFile = "/var/run/secrets/eks.amazonaws.com/serviceaccount/token" + private val testSessionName = "test-session" + private val testAccessKeyId = "ASIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + private val futureExpiration = Instant.now().plusSeconds(3600) + + // Mock environment + private def mockEnv(envVars: Map[String, String]): Env[IO] = + new Env[IO]: + override def get(name: String): IO[Option[String]] = + IO.pure(envVars.get(name)) + override def entries: IO[scala.collection.immutable.Iterable[(String, String)]] = + 
IO.pure(envVars) + + // Mock system properties + private def mockSystemProperties(sysProps: Map[String, String] = Map.empty): SystemProperties[IO] = + new SystemProperties[IO]: + override def get(name: String): IO[Option[String]] = + IO.pure(sysProps.get(name)) + override def clear(key: String): IO[Option[String]] = + IO.raiseError(new UnsupportedOperationException("clear not supported in mock")) + override def set(key: String, value: String): IO[Option[String]] = + IO.raiseError(new UnsupportedOperationException("set not supported in mock")) + + // Mock WebIdentityCredentialsUtils + private def mockWebIdentityUtils( + response: Option[AwsCredentials] = None, + shouldFail: Boolean = false, + errorMessage: String = "Mock STS error", + captureRequestsTo: Option[Ref[IO, List[WebIdentityTokenCredentialProperties]]] = None + ): WebIdentityCredentialsUtils[IO] = + (config: WebIdentityTokenCredentialProperties) => + for _ <- captureRequestsTo match + case Some(ref) => ref.update(_ :+ config) + case None => IO.unit + yield + if shouldFail then throw new SdkClientException(errorMessage) + else + response.getOrElse( + AwsSessionCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, + validateCredentials = false, + providerName = None, + accountId = Some("123456789012"), + expirationTime = Some(futureExpiration) + ) + ) + + test("resolveCredentials with successful Web Identity Token authentication") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile, + "AWS_ROLE_ARN" -> testRoleArn, + "AWS_ROLE_SESSION_NAME" -> testSessionName + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val requestCapture = Ref.unsafe[IO, List[WebIdentityTokenCredentialProperties]](List.empty) + val webIdentityUtils = mockWebIdentityUtils(captureRequestsTo = Some(requestCapture)) + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + 
for + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + assertEquals(session.accountId, Some("123456789012")) + assertEquals(session.expirationTime, Some(futureExpiration)) + case _ => fail("Expected AwsSessionCredentials") + + // Verify the correct configuration was passed + assertEquals(requests.length, 1) + val request = requests.head + assertEquals(request.webIdentityTokenFile.toString, testTokenFile) + assertEquals(request.roleArn, testRoleArn) + assertEquals(request.roleSessionName, Some(testSessionName)) + } + + test("resolveCredentials with system properties instead of environment variables") { + val sysProps = Map( + "aws.webIdentityTokenFile" -> testTokenFile, + "aws.roleArn" -> testRoleArn, + "aws.roleSessionName" -> testSessionName + ) + + given Env[IO] = mockEnv(Map.empty) + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val webIdentityUtils = mockWebIdentityUtils() + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + provider.resolveCredentials().map { + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("environment variables take precedence over system properties") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile, + "AWS_ROLE_ARN" -> testRoleArn + ) + + val sysProps = Map( + "aws.webIdentityTokenFile" -> "/wrong/path", + "aws.roleArn" -> "arn:aws:iam::999999999999:role/wrong-role" + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties(sysProps) + + val requestCapture = Ref.unsafe[IO, List[WebIdentityTokenCredentialProperties]](List.empty) + val webIdentityUtils = 
mockWebIdentityUtils(captureRequestsTo = Some(requestCapture)) + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + for + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => assert(true) + case _ => fail("Expected AwsSessionCredentials") + + // Verify environment variables were used + val request = requests.head + assertEquals(request.webIdentityTokenFile.toString, testTokenFile) + assertEquals(request.roleArn, testRoleArn) + } + + test("resolveCredentials without role session name uses None") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile, + "AWS_ROLE_ARN" -> testRoleArn + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val requestCapture = Ref.unsafe[IO, List[WebIdentityTokenCredentialProperties]](List.empty) + val webIdentityUtils = mockWebIdentityUtils(captureRequestsTo = Some(requestCapture)) + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + for + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => assert(true) + case _ => fail("Expected AwsSessionCredentials") + + // Verify no session name was provided + val request = requests.head + assertEquals(request.roleSessionName, None) + } + + test("fail when token file is not specified") { + val envVars = Map( + "AWS_ROLE_ARN" -> testRoleArn + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val webIdentityUtils = mockWebIdentityUtils() + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load Web Identity Token credentials")) + 
assert(exception.getMessage.contains("AWS_WEB_IDENTITY_TOKEN_FILE")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when role ARN is not specified") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val webIdentityUtils = mockWebIdentityUtils() + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load Web Identity Token credentials")) + assert(exception.getMessage.contains("AWS_ROLE_ARN")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when both token file and role ARN are missing") { + given Env[IO] = mockEnv(Map.empty) + given SystemProperties[IO] = mockSystemProperties() + + val webIdentityUtils = mockWebIdentityUtils() + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assert(exception.getMessage.contains("Unable to load Web Identity Token credentials")) + case _ => fail("Expected SdkClientException") + } + } + + test("fail when WebIdentityCredentialsUtils throws an exception") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile, + "AWS_ROLE_ARN" -> testRoleArn + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val webIdentityUtils = mockWebIdentityUtils(shouldFail = true, errorMessage = "STS service unavailable") + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + provider.resolveCredentials().attempt.map { + case Left(exception: SdkClientException) => + assertEquals(exception.getMessage, "STS service unavailable") + case _ => fail("Expected SdkClientException") + } + } + + 
test("implements AwsCredentialsProvider trait") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile, + "AWS_ROLE_ARN" -> testRoleArn + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val webIdentityUtils = mockWebIdentityUtils() + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + // Type check passes - provider implements the trait correctly + val providerTrait: ldbc.amazon.identity.AwsCredentialsProvider[IO] = provider + + assertIOBoolean(providerTrait.resolveCredentials().map(_.isInstanceOf[AwsCredentials])) + } + + test("isAvailable returns true when required configuration is present") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile, + "AWS_ROLE_ARN" -> testRoleArn + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + assertIOBoolean( + WebIdentityTokenFileCredentialsProvider.isAvailable[IO]() + ) + } + + test("isAvailable returns true with system properties") { + val sysProps = Map( + "aws.webIdentityTokenFile" -> testTokenFile, + "aws.roleArn" -> testRoleArn + ) + + given Env[IO] = mockEnv(Map.empty) + given SystemProperties[IO] = mockSystemProperties(sysProps) + + assertIOBoolean(WebIdentityTokenFileCredentialsProvider.isAvailable[IO]()) + } + + test("isAvailable returns false when token file is missing") { + val envVars = Map( + "AWS_ROLE_ARN" -> testRoleArn + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + assertIO(WebIdentityTokenFileCredentialsProvider.isAvailable[IO](), false) + } + + test("isAvailable returns false when role ARN is missing") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> testTokenFile + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + assertIO(WebIdentityTokenFileCredentialsProvider.isAvailable[IO](), false) + } + + test("isAvailable returns false when both 
are missing") { + given Env[IO] = mockEnv(Map.empty) + given SystemProperties[IO] = mockSystemProperties() + + assertIO(WebIdentityTokenFileCredentialsProvider.isAvailable[IO](), false) + } + + test("isAvailable returns false when values are empty or whitespace") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> " ", + "AWS_ROLE_ARN" -> "" + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + assertIO(WebIdentityTokenFileCredentialsProvider.isAvailable[IO](), false) + } + + test("handles whitespace trimming in configuration values") { + val envVars = Map( + "AWS_WEB_IDENTITY_TOKEN_FILE" -> s" $testTokenFile ", + "AWS_ROLE_ARN" -> s" $testRoleArn ", + "AWS_ROLE_SESSION_NAME" -> s" $testSessionName " + ) + + given Env[IO] = mockEnv(envVars) + given SystemProperties[IO] = mockSystemProperties() + + val requestCapture = Ref.unsafe[IO, List[WebIdentityTokenCredentialProperties]](List.empty) + val webIdentityUtils = mockWebIdentityUtils(captureRequestsTo = Some(requestCapture)) + val provider = WebIdentityTokenFileCredentialsProvider.create[IO](webIdentityUtils) + + for + credentials <- provider.resolveCredentials() + requests <- requestCapture.get + yield + credentials match + case _: AwsSessionCredentials => assert(true) + case _ => fail("Expected AwsSessionCredentials") + + // Verify trimmed values were used + val request = requests.head + assertEquals(request.webIdentityTokenFile.toString, testTokenFile) + assertEquals(request.roleArn, testRoleArn) + assertEquals(request.roleSessionName, Some(testSessionName)) + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala new file mode 100644 index 000000000..3d47e0226 --- /dev/null +++ 
b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/credentials/internal/WebIdentityCredentialsUtilsTest.scala @@ -0,0 +1,333 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.credentials.internal + +import java.time.Instant + +import cats.effect.{ IO, Ref } + +import fs2.io.file.Files + +import munit.CatsEffectSuite + +import ldbc.amazon.auth.credentials.{ AwsSessionCredentials, WebIdentityTokenCredentialProperties } +import ldbc.amazon.client.StsClient +import ldbc.amazon.exception.{ InvalidTokenException, StsException } + +class WebIdentityCredentialsUtilsTest extends CatsEffectSuite: + + // Test fixtures + private val testRoleArn = "arn:aws:iam::123456789012:role/test-role" + private val testSessionName = "test-session" + private val testAccessKeyId = "ASIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + private val futureExpiration = Instant.now().plusSeconds(3600) + private val testAssumedRoleArn = "arn:aws:sts::123456789012:assumed-role/test-role/test-session" + + // Valid JWT token (simplified for testing) + private val validJwtToken = + "eyJhbGciOiJSUzI1NiIsImtpZCI6IjEyMyJ9.eyJpc3MiOiJodHRwczovL29pZGMuZWtzLnVzLWVhc3QtMS5hbWF6b25hd3MuY29tL2lkLzEyMyIsInN1YiI6InN5c3RlbTpzZXJ2aWNlYWNjb3VudDpkZWZhdWx0Om15LWFwcCJ9.signature" + + // For these tests, we'll focus on mocking the StsClient behavior since Files[IO] is sealed + // We'll create temporary files for file-based tests when needed + + // Mock StsClient + private def mockStsClient( + response: Option[StsClient.AssumeRoleWithWebIdentityResponse] = None, + shouldFail: Boolean = false, + errorMessage: String = "Mock STS error", + captureRequestsTo: Option[Ref[IO, 
List[StsClient.AssumeRoleWithWebIdentityRequest]]] = None + ): StsClient[IO] = + (request: StsClient.AssumeRoleWithWebIdentityRequest) => + for _ <- captureRequestsTo match + case Some(ref) => ref.update(_ :+ request) + case None => IO.unit + yield + if shouldFail then throw new StsException(errorMessage) + else + response.getOrElse( + StsClient.AssumeRoleWithWebIdentityResponse( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, + expiration = futureExpiration, + assumedRoleArn = testAssumedRoleArn + ) + ) + + test("assumeRoleWithWebIdentity with successful flow") { + // Create a temporary file with the JWT token for testing + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val requestCapture = Ref.unsafe[IO, List[StsClient.AssumeRoleWithWebIdentityRequest]](List.empty) + val stsClient = mockStsClient(captureRequestsTo = Some(requestCapture)) + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + credentials <- webIdentityUtils.assumeRoleWithWebIdentity(config) + requests <- requestCapture.get + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + assertEquals(session.secretAccessKey, testSecretAccessKey) + assertEquals(session.sessionToken, testSessionToken) + assertEquals(session.validateCredentials, false) + assertEquals(session.providerName, None) + assertEquals(session.accountId, Some("123456789012")) + 
assertEquals(session.expirationTime, Some(futureExpiration)) + case _ => fail("Expected AwsSessionCredentials") + + // Verify STS request was made correctly + assertEquals(requests.length, 1) + val request = requests.head + assertEquals(request.roleArn, testRoleArn) + assertEquals(request.webIdentityToken, validJwtToken) + assertEquals(request.roleSessionName, Some(testSessionName)) + } + } + + // Focus on tests that can be validated without complex file mocking + // JWT validation tests use temporary files for realistic testing + + test("fail when JWT token has invalid format - too few parts") { + val invalidToken = "header.payload" // Missing signature part + + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(invalidToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val stsClient = mockStsClient() + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + result <- webIdentityUtils.assumeRoleWithWebIdentity(config).attempt + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield result match + case Left(exception: InvalidTokenException) => + assert(exception.getMessage.contains("Invalid JWT token format")) + assert(exception.getMessage.contains("Expected 3 parts, got 2")) + case _ => fail("Expected InvalidTokenException") + } + } + + test("fail when JWT token has invalid format - too many parts") { + val invalidToken = "header.payload.signature.extra" // Too many parts + + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- 
Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(invalidToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val stsClient = mockStsClient() + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + result <- webIdentityUtils.assumeRoleWithWebIdentity(config).attempt + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield result match + case Left(exception: InvalidTokenException) => + assert(exception.getMessage.contains("Invalid JWT token format")) + assert(exception.getMessage.contains("Expected 3 parts, got 4")) + case _ => fail("Expected InvalidTokenException") + } + } + + test("fail when JWT token has empty parts") { + val invalidToken = "header..signature" // Empty payload + + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(invalidToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val stsClient = mockStsClient() + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + result <- webIdentityUtils.assumeRoleWithWebIdentity(config).attempt + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield result match + case Left(exception: InvalidTokenException) => + assert(exception.getMessage.contains("JWT token contains empty parts")) + case _ => fail("Expected InvalidTokenException") + } + } + + test("fail when STS client throws exception") { + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, 
"webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val stsClient = mockStsClient(shouldFail = true, errorMessage = "STS service unavailable") + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + result <- webIdentityUtils.assumeRoleWithWebIdentity(config).attempt + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield result match + case Left(exception: StsException) => + assertEquals(exception.getMessage, "STS service unavailable") + case _ => fail("Expected StsException") + } + } + + test("extract account ID from ARN correctly") { + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val customAssumedRoleArn = "arn:aws:sts::999888777666:assumed-role/my-role/my-session" + val customResponse = StsClient.AssumeRoleWithWebIdentityResponse( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, + expiration = futureExpiration, + assumedRoleArn = customAssumedRoleArn + ) + + val stsClient = mockStsClient(response = Some(customResponse)) + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + credentials <- 
webIdentityUtils.assumeRoleWithWebIdentity(config) + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield credentials match + case session: AwsSessionCredentials => + assertEquals(session.accountId, Some("999888777666")) + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("handle malformed ARN gracefully") { + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(validJwtToken)).compile.drain + yield tokenFile + + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val malformedArn = "invalid:arn:format" + val customResponse = StsClient.AssumeRoleWithWebIdentityResponse( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, + expiration = futureExpiration, + assumedRoleArn = malformedArn + ) + + val stsClient = mockStsClient(response = Some(customResponse)) + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + credentials <- webIdentityUtils.assumeRoleWithWebIdentity(config) + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield credentials match + case session: AwsSessionCredentials => + assertEquals(session.accountId, None) // Should be None for malformed ARN + case _ => fail("Expected AwsSessionCredentials") + } + } + + test("reads token from file with correct trimming") { + val tokenWithWhitespace = s" $validJwtToken \n\t" + + val tempFile = for + tempDir <- Files[IO].createTempDirectory(None, "webidentity-test", None) + tokenFile <- Files[IO].createTempFile(Some(tempDir), "jwt-token", ".txt", None) + _ <- Files[IO].writeUtf8(tokenFile).apply(fs2.Stream.emit(tokenWithWhitespace)).compile.drain + yield tokenFile 
+ + tempFile.flatMap { tokenFilePath => + val config = WebIdentityTokenCredentialProperties( + webIdentityTokenFile = tokenFilePath, + roleArn = testRoleArn, + roleSessionName = Some(testSessionName) + ) + + val requestCapture = Ref.unsafe[IO, List[StsClient.AssumeRoleWithWebIdentityRequest]](List.empty) + val stsClient = mockStsClient(captureRequestsTo = Some(requestCapture)) + val webIdentityUtils = WebIdentityCredentialsUtils.create[IO](stsClient) + + for + credentials <- webIdentityUtils.assumeRoleWithWebIdentity(config) + requests <- requestCapture.get + _ <- Files[IO].deleteIfExists(tokenFilePath) // Cleanup + yield + credentials match + case session: AwsSessionCredentials => + assertEquals(session.accessKeyId, testAccessKeyId) + case _ => fail("Expected AwsSessionCredentials") + + // Verify the token was trimmed correctly + assertEquals(requests.length, 1) + val request = requests.head + assertEquals(request.webIdentityToken, validJwtToken) // Should be trimmed + } + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGeneratorTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGeneratorTest.scala new file mode 100644 index 000000000..d06129852 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/auth/token/RdsIamAuthTokenGeneratorTest.scala @@ -0,0 +1,499 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.auth.token + +import java.net.URLDecoder +import java.time.Instant + +import cats.effect.IO + +import fs2.hashing.Hashing + +import munit.CatsEffectSuite + +import ldbc.amazon.auth.credentials.{ AwsBasicCredentials, AwsSessionCredentials } + +class RdsIamAuthTokenGeneratorTest extends CatsEffectSuite: + + // Test fixtures + private val testHostname = "my-db-instance.ap-northeast-1.rds.amazonaws.com" + private val testPort = 3306 + private val testUsername = "db_user" + private val testRegion = "ap-northeast-1" + + private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testSessionToken = "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3" + + private def createGenerator(): RdsIamAuthTokenGenerator[IO] = + new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = testPort, + username = testUsername, + region = testRegion + ) + + test("generateToken with basic credentials") { + val generator = createGenerator() + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = Some("test-provider"), + accountId = Some("123456789012"), + expirationTime = Some(Instant.now().plusSeconds(3600)) + ) + + generator.generateToken(credentials).map { token => + // Verify token structure + assert(token.startsWith(s"$testHostname:$testPort/?")) + assert(token.contains("Action=connect")) + assert(token.contains(s"DBUser=$testUsername")) + assert(token.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256")) + assert(token.contains(s"X-Amz-Credential=$testAccessKeyId")) + assert(token.contains("X-Amz-Date=")) + assert(token.contains("X-Amz-Expires=900")) + assert(token.contains("X-Amz-SignedHeaders=host")) + assert(token.contains("X-Amz-Signature=")) + + // Should NOT contain 
session token for basic credentials + assert(!token.contains("X-Amz-Security-Token")) + } + } + + test("generateToken with session credentials") { + val generator = createGenerator() + val credentials = AwsSessionCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + sessionToken = testSessionToken, + validateCredentials = true, + providerName = Some("test-provider"), + accountId = Some("123456789012"), + expirationTime = Some(Instant.now().plusSeconds(3600)) + ) + + generator.generateToken(credentials).map { token => + // Verify token structure + assert(token.startsWith(s"$testHostname:$testPort/?")) + assert(token.contains("Action=connect")) + assert(token.contains(s"DBUser=$testUsername")) + assert(token.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256")) + assert(token.contains(s"X-Amz-Credential=$testAccessKeyId")) + assert(token.contains("X-Amz-Date=")) + assert(token.contains("X-Amz-Expires=900")) + assert(token.contains("X-Amz-SignedHeaders=host")) + assert(token.contains("X-Amz-Signature=")) + + // Should contain session token for session credentials + assert(token.contains("X-Amz-Security-Token=")) + assert(token.contains(java.net.URLEncoder.encode(testSessionToken, "UTF-8"))) + } + } + + test("generateToken produces consistent output structure") { + val generator = createGenerator() + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + for { + token1 <- generator.generateToken(credentials) + token2 <- generator.generateToken(credentials) + } yield { + // Tokens may differ due to timestamp, but structure should be consistent + assert(token1.startsWith(s"$testHostname:$testPort/?")) + assert(token2.startsWith(s"$testHostname:$testPort/?")) + + // Both should contain the same parameter names + val params1 = token1.split("\\?", 
2)(1).split("&").map(_.split("=")(0)).toSet + val params2 = token2.split("\\?", 2)(1).split("&").map(_.split("=")(0)).toSet + assertEquals(params1, params2) + } + } + + test("token format validation") { + val generator = createGenerator() + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + // Parse the token URL + val parts = token.split("\\?", 2) + assertEquals(parts.length, 2) + + val hostPart = parts(0) + val queryPart = parts(1) + + assertEquals(hostPart, s"$testHostname:$testPort/") + + // Parse query parameters + val queryParams = queryPart + .split("&") + .map { param => + val kv = param.split("=", 2) + kv(0) -> URLDecoder.decode(kv(1), "UTF-8") + } + .toMap + + // Verify required parameters + assertEquals(queryParams("Action"), "connect") + assertEquals(queryParams("DBUser"), testUsername) + assertEquals(queryParams("X-Amz-Algorithm"), "AWS4-HMAC-SHA256") + assertEquals(queryParams("X-Amz-Expires"), "900") + assertEquals(queryParams("X-Amz-SignedHeaders"), "host") + + // Verify credential format + val credential = queryParams("X-Amz-Credential") + assert(credential.startsWith(testAccessKeyId)) + assert(credential.contains(testRegion)) + assert(credential.contains("rds-db")) + assert(credential.contains("aws4_request")) + + // Verify date format (should be ISO 8601 basic format) + val date = queryParams("X-Amz-Date") + assert(date.matches("""\d{8}T\d{6}Z""")) + + // Verify signature is present and non-empty + val signature = queryParams("X-Amz-Signature") + assert(signature.nonEmpty) + assert(signature.matches("[0-9a-f]+")) // Should be lowercase hex + } + } + + test("token parameter ordering") { + val generator = createGenerator() + val credentials = AwsSessionCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = 
testSecretAccessKey, + sessionToken = testSessionToken, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + // Extract query string + val queryString = token.split("\\?", 2)(1) + val paramsBefore = queryString.split("&").filterNot(_.startsWith("X-Amz-Signature=")).toList + val paramsSorted = paramsBefore.sortBy(_.split("=")(0)) + + // Parameters should be in alphabetical order (required for AWS SigV4) + assertEquals(paramsBefore, paramsSorted) + } + } + + test("different credentials produce different signatures") { + val generator = createGenerator() + + val credentials1 = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + val credentials2 = AwsBasicCredentials( + accessKeyId = "AKIADIFFERENTKEY", + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + for { + token1 <- generator.generateToken(credentials1) + token2 <- generator.generateToken(credentials2) + } yield { + assert(token1 != token2) + + // Extract signatures + val sig1 = token1.split("X-Amz-Signature=")(1) + val sig2 = token2.split("X-Amz-Signature=")(1) + assert(sig1 != sig2) + } + } + + test("different regions produce different tokens") { + val generator1 = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = testPort, + username = testUsername, + region = "us-east-1" + ) + + val generator2 = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = testPort, + username = testUsername, + region = "us-west-2" + ) + + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + for 
{ + token1 <- generator1.generateToken(credentials) + token2 <- generator2.generateToken(credentials) + } yield { + assert(token1 != token2) + + // Verify region is included in credential scope + assert(token1.contains("us-east-1")) + assert(token2.contains("us-west-2")) + } + } + + test("different hostnames produce different tokens") { + val generator1 = new RdsIamAuthTokenGenerator[IO]( + hostname = "host1.region.rds.amazonaws.com", + port = testPort, + username = testUsername, + region = testRegion + ) + + val generator2 = new RdsIamAuthTokenGenerator[IO]( + hostname = "host2.region.rds.amazonaws.com", + port = testPort, + username = testUsername, + region = testRegion + ) + + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + for { + token1 <- generator1.generateToken(credentials) + token2 <- generator2.generateToken(credentials) + } yield { + assert(token1 != token2) + assert(token1.startsWith("host1.region.rds.amazonaws.com")) + assert(token2.startsWith("host2.region.rds.amazonaws.com")) + } + } + + test("different ports produce different tokens") { + val generator1 = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = 3306, + username = testUsername, + region = testRegion + ) + + val generator2 = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = 5432, + username = testUsername, + region = testRegion + ) + + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + for { + token1 <- generator1.generateToken(credentials) + token2 <- generator2.generateToken(credentials) + } yield { + assert(token1 != token2) + assert(token1.contains(":3306/")) + assert(token2.contains(":5432/")) + } + } + + 
test("different usernames produce different tokens") { + val generator1 = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = testPort, + username = "user1", + region = testRegion + ) + + val generator2 = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = testPort, + username = "user2", + region = testRegion + ) + + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + for { + token1 <- generator1.generateToken(credentials) + token2 <- generator2.generateToken(credentials) + } yield { + assert(token1 != token2) + assert(token1.contains("DBUser=user1")) + assert(token2.contains("DBUser=user2")) + } + } + + test("URL encoding is applied correctly") { + val generator = new RdsIamAuthTokenGenerator[IO]( + hostname = testHostname, + port = testPort, + username = "user with spaces@domain.com", + region = testRegion + ) + + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + // Username should be URL encoded + assert(token.contains("DBUser=user%20with%20spaces%40domain.com")) + + // Verify AWS-specific encoding rules + assert(!token.contains("+")) // Spaces should be %20, not + + + // Extract all parameters to verify proper encoding + val queryString = token.split("\\?", 2)(1) + val params = queryString.split("&") + + params.foreach { param => + val parts = param.split("=", 2) + if parts.length == 2 then { + val value = parts(1) + + // Verify no invalid characters remain unencoded + assert(!value.contains(" ")) + assert(!value.contains("@")) + } + } + } + } + + test("signature calculation consistency") { + val generator = createGenerator() + val credentials = 
AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + val signature = token.split("X-Amz-Signature=")(1) + + // Signature should be 64 characters (SHA-256 hex) + assertEquals(signature.length, 64) + + // Signature should be lowercase hex + assert(signature.forall(c => c.isDigit || ('a' <= c && c <= 'f'))) + } + } + + test("implements AuthTokenGenerator trait") { + val generator: AuthTokenGenerator[IO] = createGenerator() + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + assert(token.nonEmpty) + assert(token.contains(testHostname)) + } + } + + test("token contains correct service and terminator") { + val generator = createGenerator() + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + // Extract credential scope from the credential parameter + val queryString = token.split("\\?", 2)(1) + val credentialParam = queryString + .split("&") + .find(_.startsWith("X-Amz-Credential=")) + .getOrElse(fail("X-Amz-Credential parameter not found")) + + val credentialValue = URLDecoder.decode(credentialParam.split("=")(1), "UTF-8") + + // Credential format: accessKeyId/date/region/service/terminator + val parts = credentialValue.split("/") + assertEquals(parts.length, 5) + assertEquals(parts(0), testAccessKeyId) + assertEquals(parts(2), testRegion) + assertEquals(parts(3), "rds-db") // SERVICE constant + assertEquals(parts(4), "aws4_request") // 
TERMINATOR constant + } + } + + test("token expiration is set to 900 seconds") { + val generator = createGenerator() + val credentials = AwsBasicCredentials( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + validateCredentials = true, + providerName = None, + accountId = None, + expirationTime = None + ) + + generator.generateToken(credentials).map { token => + assert(token.contains("X-Amz-Expires=900")) + } + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/SimpleHttpClientSecurityTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/SimpleHttpClientSecurityTest.scala new file mode 100644 index 000000000..6747a5487 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/SimpleHttpClientSecurityTest.scala @@ -0,0 +1,125 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import java.net.URI + +import scala.concurrent.duration.* + +import cats.effect.IO + +import fs2.io.net.Network + +import munit.CatsEffectSuite + +class SimpleHttpClientSecurityTest extends CatsEffectSuite: + + val client = new SimpleHttpClient[IO](connectTimeout = 5.seconds, readTimeout = 10.seconds) + + test("reject URIs without scheme") { + val uri = URI.create("example.com/path") + + client.get(uri, Map.empty).attempt.map { result => + assert(result.isLeft) + assert(result.left.toOption.get.getMessage.contains("URI scheme is required")) + } + } + + test("reject unsupported URI schemes") { + val uri = URI.create("ftp://example.com/path") + + client.get(uri, Map.empty).attempt.map { result => + assert(result.isLeft) + assert(result.left.toOption.get.getMessage.contains("Unsupported URI scheme: ftp")) + } + } + + test("accept HTTP URIs") { + val uri = 
URI.create("http://httpbin.org/get") + + // This test will fail if network is not available, but validates the scheme is accepted + client.get(uri, Map.empty).attempt.map { result => + // Either succeeds or fails with network error (not scheme validation error) + result.fold( + error => assert(!error.getMessage.contains("Unsupported URI scheme")), + _ => assert(true) + ) + } + } + + test("accept HTTPS URIs and use correct default port") { + val uri = URI.create("https://httpbin.org/get") + + // This test validates HTTPS scheme is accepted and would use port 443 + client.get(uri, Map.empty).attempt.map { result => + // Either succeeds or fails with network error (not scheme validation error) + result.fold( + error => assert(!error.getMessage.contains("Unsupported URI scheme")), + _ => assert(true) + ) + } + } + + test("use correct default ports") { + // HTTP should default to port 80 + val httpUri = URI.create("http://example.com") + assert(httpUri.getPort == -1) // No explicit port + + // HTTPS should default to port 443 + val httpsUri = URI.create("https://example.com") + assert(httpsUri.getPort == -1) // No explicit port + + // The client should handle these correctly internally + assert(true) // This is tested implicitly by the above tests + } + + test("preserve explicit ports in URIs") { + val httpUri = URI.create("http://example.com:8080/path") + val httpsUri = URI.create("https://example.com:8443/path") + + assert(httpUri.getPort == 8080) + assert(httpsUri.getPort == 8443) + } + + test("AWS STS HTTPS endpoints should be accepted") { + // These are the actual endpoints the StsClient will use + val stsEndpoints = List( + "https://sts.us-east-1.amazonaws.com/", + "https://sts.eu-west-1.amazonaws.com/", + "https://sts.ap-northeast-1.amazonaws.com/" + ) + + stsEndpoints.foreach { endpoint => + val uri = URI.create(endpoint) + + client.get(uri, Map.empty).attempt.map { result => + // Should not fail with scheme validation error + result.fold( + error => 
assert(!error.getMessage.contains("Unsupported URI scheme")), + _ => assert(true) + ) + } + } + } + + test("reject HTTP requests to AWS endpoints") { + val insecureAWSEndpoint = URI.create("http://sts.us-east-1.amazonaws.com/") + + client.get(insecureAWSEndpoint, Map.empty).attempt.map { result => + assert(result.isLeft) + assert(result.left.toOption.get.getMessage.contains("AWS endpoints require HTTPS")) + } + } + + test("reject HTTP PUT requests to AWS endpoints") { + val insecureAWSEndpoint = URI.create("http://sts.us-east-1.amazonaws.com/") + + client.put(insecureAWSEndpoint, Map.empty, "test body").attempt.map { result => + assert(result.isLeft) + assert(result.left.toOption.get.getMessage.contains("AWS endpoints require HTTPS")) + } + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala new file mode 100644 index 000000000..e427e2029 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/client/StsClientTest.scala @@ -0,0 +1,137 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.client + +import java.time.Instant + +import scala.concurrent.duration.* + +import cats.effect.std.UUIDGen +import cats.effect.IO + +import fs2.io.net.Network + +import munit.CatsEffectSuite + +import ldbc.amazon.exception.{ SdkClientException, StsException } + +class StsClientTest extends CatsEffectSuite: + + // LocalStack STS endpoint (from docker-compose.yml) + private val localStackEndpoint = "http://localhost:4566" + + private val testRoleArn = "arn:aws:iam::000000000000:role/localstack-role" + private val roleSessionName = "test-session" + private val testAssumedRoleArn = s"arn:aws:sts::000000000000:assumed-role/localstack-role/$roleSessionName" + private val testWebIdentityToken = + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + + // HTTP client for testing + private def httpClient: SimpleHttpClient[IO] = + SimpleHttpClient[IO](connectTimeout = 5.seconds, readTimeout = 10.seconds) + + // STS client configured for LocalStack + private def localStackStsClient: StsClient[IO] = StsClient.build(localStackEndpoint, httpClient) + + test("assumeRoleWithWebIdentity with LocalStack".flaky) { + val request = StsClient.AssumeRoleWithWebIdentityRequest( + roleArn = testRoleArn, + webIdentityToken = testWebIdentityToken, + roleSessionName = Some(roleSessionName), + durationSeconds = Some(1800) + ) + + localStackStsClient.assumeRoleWithWebIdentity(request).map { response => + assert(response.accessKeyId.nonEmpty) + assert(response.secretAccessKey.nonEmpty) + assert(response.sessionToken.nonEmpty) + assert(response.expiration.isAfter(Instant.now())) + assertEquals(response.assumedRoleArn, testAssumedRoleArn) + } + } + + test("handle invalid role ARN") { + val client = localStackStsClient + val request = StsClient.AssumeRoleWithWebIdentityRequest( + roleArn = 
"invalid-role-arn", + webIdentityToken = testWebIdentityToken + ) + + assertIOBoolean(client.assumeRoleWithWebIdentity(request).attempt.map(_.isLeft)) + } + + test("assumeRoleWithWebIdentity with auto-generated session name") { + val client = localStackStsClient + val request = StsClient.AssumeRoleWithWebIdentityRequest( + roleArn = testRoleArn, + webIdentityToken = testWebIdentityToken + // No roleSessionName - should be auto-generated + ) + + client.assumeRoleWithWebIdentity(request).attempt.map { result => + // Test should handle auto-generation of session name + result.fold( + error => { + // Expected for LocalStack without proper STS setup + assert(error.isInstanceOf[StsException] || error.isInstanceOf[SdkClientException]) + }, + response => { + assert(response.accessKeyId.nonEmpty) + assert(response.secretAccessKey.nonEmpty) + assert(response.sessionToken.nonEmpty) + } + ) + } + } + + test("AssumeRoleWithWebIdentityResponse validation") { + val now = Instant.now() + val response = StsClient.AssumeRoleWithWebIdentityResponse( + accessKeyId = "ASIAIOSFODNN7EXAMPLE", + secretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", + sessionToken = "session-token", + expiration = now.plusSeconds(3600), + assumedRoleArn = testRoleArn + ) + + assertEquals(response.accessKeyId, "ASIAIOSFODNN7EXAMPLE") + assertEquals(response.secretAccessKey, "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") + assertEquals(response.sessionToken, "session-token") + assertEquals(response.expiration, now.plusSeconds(3600)) + assertEquals(response.assumedRoleArn, testRoleArn) + } + + test("buildRequestBody format") { + // Test the query parameter formatting + val request = StsClient.AssumeRoleWithWebIdentityRequest( + roleArn = "arn:aws:iam::123456789012:role/TestRole", + webIdentityToken = "test-token", + roleSessionName = Some("test-session"), + durationSeconds = Some(1800) + ) + + // We can't access the private method directly, but we can test the behavior + // by ensuring our LocalStack 
client formats parameters correctly + val queryParams = Map( + "Action" -> "AssumeRoleWithWebIdentity", + "Version" -> "2011-06-15", + "RoleArn" -> request.roleArn, + "WebIdentityToken" -> request.webIdentityToken, + "RoleSessionName" -> request.roleSessionName.getOrElse("ldbc-session"), + "DurationSeconds" -> request.durationSeconds.getOrElse(3600).toString + ) + + queryParams.foreach { + case (key, value) => + assert(key.nonEmpty) + assert(value.nonEmpty) + } + + assert(queryParams("Action") == "AssumeRoleWithWebIdentity") + assert(queryParams("Version") == "2011-06-15") + assert(queryParams("RoleArn") == request.roleArn) + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentityTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentityTest.scala new file mode 100644 index 000000000..871368574 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/identity/internal/DefaultAwsCredentialsIdentityTest.scala @@ -0,0 +1,389 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.identity.internal + +import java.time.Instant + +import munit.CatsEffectSuite + +import ldbc.amazon.identity.AwsCredentialsIdentity + +class DefaultAwsCredentialsIdentityTest extends CatsEffectSuite: + + private val testAccessKeyId = "AKIAIOSFODNN7EXAMPLE" + private val testSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + private val testAccountId = "123456789012" + private val testProviderName = "test-provider" + private val testExpirationTime = Instant.parse("2024-12-06T12:00:00Z") + + test("create minimal credentials identity") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) + + assertEquals(credentials.accessKeyId, testAccessKeyId) + assertEquals(credentials.secretAccessKey, testSecretAccessKey) + assertEquals(credentials.accountId, None) + assertEquals(credentials.expirationTime, None) + assertEquals(credentials.providerName, None) + } + + test("create full credentials identity with all fields") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + assertEquals(credentials.accessKeyId, testAccessKeyId) + assertEquals(credentials.secretAccessKey, testSecretAccessKey) + assertEquals(credentials.accountId, Some(testAccountId)) + assertEquals(credentials.expirationTime, Some(testExpirationTime)) + assertEquals(credentials.providerName, Some(testProviderName)) + } + + test("create credentials using factory method") { + val credentials = AwsCredentialsIdentity.create(testAccessKeyId, testSecretAccessKey) + + assertEquals(credentials.accessKeyId, testAccessKeyId) + assertEquals(credentials.secretAccessKey, 
testSecretAccessKey) + assertEquals(credentials.accountId, None) + assertEquals(credentials.expirationTime, None) + assertEquals(credentials.providerName, None) + } + + test("toString should not expose secret access key") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + val stringRepresentation = credentials.toString + + // Should contain access key and other safe fields + assert(stringRepresentation.contains(testAccessKeyId)) + assert(stringRepresentation.contains(testAccountId)) + assert(stringRepresentation.contains(testProviderName)) + + // Should NOT contain secret access key + assert(!stringRepresentation.contains(testSecretAccessKey)) + + // Should contain class name + assert(stringRepresentation.contains("AwsCredentialsIdentity")) + } + + test("toString with minimal fields") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) + + val stringRepresentation = credentials.toString + + assertEquals(stringRepresentation, s"AwsCredentialsIdentity(accessKeyId=$testAccessKeyId)") + assert(!stringRepresentation.contains(testSecretAccessKey)) + } + + test("toString with partial fields") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = None + ) + + val stringRepresentation = credentials.toString + val expected = s"AwsCredentialsIdentity(accessKeyId=$testAccessKeyId, accountId=$testAccountId)" + + assertEquals(stringRepresentation, expected) + assert(!stringRepresentation.contains(testSecretAccessKey)) + } + + test("equals should work correctly for 
identical credentials") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + assertEquals(credentials1, credentials2) + assertEquals(credentials1.hashCode(), credentials2.hashCode()) + } + + test("equals should work for credentials with different expiration and provider") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(Instant.parse("2025-01-01T00:00:00Z")), // Different expiration + providerName = Some("different-provider") // Different provider + ) + + // Should still be equal because equals only compares key fields + assertEquals(credentials1, credentials2) + assertEquals(credentials1.hashCode(), credentials2.hashCode()) + } + + test("equals should return false for different access key") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = "DIFFERENT_ACCESS_KEY", + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + 
assert(credentials1 != credentials2) + assert(credentials1.hashCode() != credentials2.hashCode()) + } + + test("equals should return false for different secret access key") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = "DIFFERENT_SECRET_KEY", + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + assert(credentials1 != credentials2) + assert(credentials1.hashCode() != credentials2.hashCode()) + } + + test("equals should return false for different account ID") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some("999999999999"), + expirationTime = None, + providerName = None + ) + + assert(credentials1 != credentials2) + assert(credentials1.hashCode() != credentials2.hashCode()) + } + + test("equals should handle None vs Some account ID") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + assert(credentials1 != credentials2) + assert(credentials1.hashCode() != credentials2.hashCode()) + } + + test("equals should return false for null object") { + val credentials = DefaultAwsCredentialsIdentity( + 
accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) + + assert(credentials != null) + assert(!credentials.equals(null)) + } + + test("equals should return false for different class") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) + + assert(credentials.toString != "not a credentials object") + assert(!credentials.equals("not a credentials object")) + } + + test("equals should work with same concrete type") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + val credentials2 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, // Different account ID + expirationTime = None, + providerName = None + ) + + val credentials3 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), // Same account ID + expirationTime = None, + providerName = None + ) + + // Different account IDs (Some vs None) + assert(credentials1 != credentials2) + + // Same account IDs + assertEquals(credentials1, credentials3) + } + + test("equals should return false for different AwsCredentialsIdentity implementations") { + val credentials1 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = None, + providerName = None + ) + + // Factory method creates same type, but let's test interface behavior + val credentials2: AwsCredentialsIdentity = AwsCredentialsIdentity.create(testAccessKeyId, testSecretAccessKey) + + // These should be equal 
since factory creates DefaultAwsCredentialsIdentity + // but with different accountId (None vs Some) + assert(credentials1 != credentials2) + + // Test with exactly same fields through factory + val credentials3 = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = None, + expirationTime = None, + providerName = None + ) + + assertEquals(credentials2, credentials3) + } + + test("hashCode should be consistent") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + val hashCode1 = credentials.hashCode() + val hashCode2 = credentials.hashCode() + + assertEquals(hashCode1, hashCode2) + } + + test("implements Identity interface correctly") { + val credentials: ldbc.amazon.identity.Identity = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + assertEquals(credentials.expirationTime, Some(testExpirationTime)) + assertEquals(credentials.providerName, Some(testProviderName)) + } + + test("implements AwsCredentialsIdentity interface correctly") { + val credentials: AwsCredentialsIdentity = DefaultAwsCredentialsIdentity( + accessKeyId = testAccessKeyId, + secretAccessKey = testSecretAccessKey, + accountId = Some(testAccountId), + expirationTime = Some(testExpirationTime), + providerName = Some(testProviderName) + ) + + assertEquals(credentials.accessKeyId, testAccessKeyId) + assertEquals(credentials.secretAccessKey, testSecretAccessKey) + assertEquals(credentials.accountId, Some(testAccountId)) + assertEquals(credentials.expirationTime, Some(testExpirationTime)) + assertEquals(credentials.providerName, Some(testProviderName)) + } + + 
test("handle empty strings") { + val credentials = DefaultAwsCredentialsIdentity( + accessKeyId = "", + secretAccessKey = "", + accountId = Some(""), + expirationTime = None, + providerName = Some("") + ) + + assertEquals(credentials.accessKeyId, "") + assertEquals(credentials.secretAccessKey, "") + assertEquals(credentials.accountId, Some("")) + assertEquals(credentials.providerName, Some("")) + } diff --git a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleJsonParserTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleJsonParserTest.scala new file mode 100644 index 000000000..1db0bd749 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleJsonParserTest.scala @@ -0,0 +1,186 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.util + +import munit.CatsEffectSuite + +class SimpleJsonParserTest extends CatsEffectSuite: + + test("parse empty object") { + val result = SimpleJsonParser.parse("{}") + assert(result.isRight) + assertEquals(result.map(_.fields.size).toOption.get, 0) + } + + test("parse object with single string field") { + val json = """{"key": "value"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("key")).toOption.flatten, Some("value")) + assertEquals(result.map(_.getOrEmpty("key")).toOption.get, "value") + } + + test("parse object with multiple string fields") { + val json = + """{"AccessKeyId": "AKIAIOSFODNN7EXAMPLE", "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("AccessKeyId")).toOption.flatten, Some("AKIAIOSFODNN7EXAMPLE")) + assertEquals( + 
result.map(_.get("SecretAccessKey")).toOption.flatten, + Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") + ) + } + + test("parse object with number values") { + val json = """{"port": 3306, "timeout": 30.5}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("port")).toOption.flatten, Some("3306")) + assertEquals(result.map(_.get("timeout")).toOption.flatten, Some("30.5")) + } + + test("parse object with boolean values") { + val json = """{"enabled": true, "disabled": false}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("enabled")).toOption.flatten, Some("true")) + assertEquals(result.map(_.get("disabled")).toOption.flatten, Some("false")) + } + + test("parse object with null values") { + val json = """{"nullValue": null}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("nullValue")).toOption.flatten, None) + assertEquals(result.map(_.getOrEmpty("nullValue")).toOption.get, "") + } + + test("parse object with escaped string values") { + val json = """{"escaped": "Hello \"World\"", "newline": "Line1\nLine2"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("escaped")).toOption.flatten, Some("Hello \"World\"")) + assertEquals(result.map(_.get("newline")).toOption.flatten, Some("Line1\nLine2")) + } + + test("parse object with unicode escape sequences") { + val json = """{"unicode": "\u0041\u0042\u0043"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("unicode")).toOption.flatten, Some("ABC")) + } + + test("parse object with whitespace") { + val json = """ { "key" : "value" , "another" : 42 } """ + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("key")).toOption.flatten, Some("value")) + assertEquals(result.map(_.get("another")).toOption.flatten, 
Some("42")) + } + + test("parse object with nested objects (should skip)") { + val json = """{"nested": {"inner": "value"}, "simple": "text"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("nested")).toOption.flatten, Some("{...}")) + assertEquals(result.map(_.get("simple")).toOption.flatten, Some("text")) + } + + test("parse object with arrays (should skip)") { + val json = """{"array": [1, 2, 3], "simple": "text"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("array")).toOption.flatten, Some("[...]")) + assertEquals(result.map(_.get("simple")).toOption.flatten, Some("text")) + } + + test("JsonObject.require returns Right for existing key") { + val json = """{"AccessKeyId": "AKIAIOSFODNN7EXAMPLE"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + val required = result.map(_.require("AccessKeyId")) + assert(required.toOption.get.isRight) + assertEquals(required.toOption.get.toOption.get, "AKIAIOSFODNN7EXAMPLE") + } + + test("JsonObject.require returns Left for null value") { + val json = """{"nullValue": null}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + val required = result.map(_.require("nullValue")) + assert(required.toOption.get.isLeft) + assertEquals(required.toOption.get.left.toOption.get, "Field 'nullValue' is null") + } + + test("JsonObject.require returns Left for missing key") { + val json = """{"AccessKeyId": "AKIAIOSFODNN7EXAMPLE"}""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + val required = result.map(_.require("MissingKey")) + assert(required.toOption.get.isLeft) + assertEquals(required.toOption.get.left.toOption.get, "Required field 'MissingKey' not found") + } + + test("JsonObject.getOrEmpty returns empty string for missing key") { + val json = """{"AccessKeyId": "AKIAIOSFODNN7EXAMPLE"}""" + val result = SimpleJsonParser.parse(json) + 
assert(result.isRight) + assertEquals(result.map(_.getOrEmpty("MissingKey")).toOption.get, "") + } + + test("fail to parse invalid JSON - not an object") { + val result = SimpleJsonParser.parse("\"just a string\"") + assert(result.isLeft) + assert(result.left.toOption.get.contains("Invalid JSON: must be an object")) + } + + test("fail to parse invalid JSON - missing closing brace") { + val result = SimpleJsonParser.parse("""{"key": "value"""") + assert(result.isLeft) + } + + test("fail to parse invalid JSON - missing colon") { + val result = SimpleJsonParser.parse("""{"key" "value"}""") + assert(result.isLeft) + } + + test("fail to parse invalid JSON - unterminated string") { + val result = SimpleJsonParser.parse("""{"key": "value}""") + assert(result.isLeft) + } + + test("fail to parse invalid JSON - invalid escape sequence") { + val result = SimpleJsonParser.parse("""{"key": "value\q"}""") + assert(result.isLeft) + } + + test("fail to parse invalid JSON - incomplete unicode escape") { + val result = SimpleJsonParser.parse("""{"key": "value\u00"}""") + assert(result.isLeft) + } + + test("parse complex AWS credentials response") { + val json = """{ + "AccessKeyId": "ASIAIOSFODNN7EXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", + "Token": "IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJIMEYCIQD6m6XYcCTgK8jELjQXqKE", + "Expiration": "2024-12-06T12:34:56Z" + }""" + val result = SimpleJsonParser.parse(json) + assert(result.isRight) + assertEquals(result.map(_.get("AccessKeyId")).toOption.flatten, Some("ASIAIOSFODNN7EXAMPLE")) + assertEquals( + result.map(_.get("SecretAccessKey")).toOption.flatten, + Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") + ) + assertEquals( + result.map(_.get("Token")).toOption.flatten, + Some("IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJIMEYCIQD6m6XYcCTgK8jELjQXqKE") + ) + assertEquals(result.map(_.get("Expiration")).toOption.flatten, Some("2024-12-06T12:34:56Z")) + } diff --git 
a/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleXmlParserTest.scala b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleXmlParserTest.scala new file mode 100644 index 000000000..b0f346f81 --- /dev/null +++ b/module/ldbc-aws-authentication-plugin/shared/src/test/scala/ldbc/amazon/util/SimpleXmlParserTest.scala @@ -0,0 +1,247 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.amazon.util + +import munit.CatsEffectSuite + +class SimpleXmlParserTest extends CatsEffectSuite: + + test("decodeXmlEntities should decode all standard XML entities") { + val input = "&<>"'" + val expected = "&<>\"'" + assertEquals(SimpleXmlParser.decodeXmlEntities(input), expected) + } + + test("decodeXmlEntities should handle mixed content with entities") { + val input = "Hello & welcome to <XML> parsing "test"" + val expected = "Hello & welcome to parsing \"test\"" + assertEquals(SimpleXmlParser.decodeXmlEntities(input), expected) + } + + test("decodeXmlEntities should handle text without entities") { + val input = "Plain text without any entities" + assertEquals(SimpleXmlParser.decodeXmlEntities(input), input) + } + + test("decodeXmlEntities should handle empty string") { + assertEquals(SimpleXmlParser.decodeXmlEntities(""), "") + } + + test("extractTagContent should extract simple tag content") { + val xml = "John Doe" + val result = SimpleXmlParser.extractTagContent("name", xml) + assertEquals(result, Some("John Doe")) + } + + test("extractTagContent should extract content with whitespace trimmed") { + val xml = " John Doe " + val result = SimpleXmlParser.extractTagContent("name", xml) + assertEquals(result, Some("John Doe")) + } + + test("extractTagContent should extract content with XML entities decoded") { + val xml = "Hello & welcome to <XML>" + val result 
= SimpleXmlParser.extractTagContent("message", xml) + assertEquals(result, Some("Hello & welcome to ")) + } + + test("extractTagContent should return None for non-existent tag") { + val xml = "John Doe" + val result = SimpleXmlParser.extractTagContent("email", xml) + assertEquals(result, None) + } + + test("extractTagContent should return None for malformed XML (missing end tag)") { + val xml = "John Doe" + val result = SimpleXmlParser.extractTagContent("name", xml) + assertEquals(result, None) + } + + test("extractTagContent should extract from complex XML structure") { + val xml = """ + + + John Doe + john@example.com + + + """ + val nameResult = SimpleXmlParser.extractTagContent("name", xml) + val emailResult = SimpleXmlParser.extractTagContent("email", xml) + assertEquals(nameResult, Some("John Doe")) + assertEquals(emailResult, Some("john@example.com")) + } + + test("extractTagContent should handle nested tags with same name") { + val xml = "Outer NameInner Name" + val result = SimpleXmlParser.extractTagContent("name", xml) + // Should extract the first occurrence + assertEquals(result, Some("Outer Name")) + } + + test("extractSection should extract complete XML section") { + val xml = """ + + + John Doe + john@example.com + + success + + """ + val result = SimpleXmlParser.extractSection("user", xml) + assert(result.isDefined) + val userSection = result.get + assert(userSection.contains("")) + assert(userSection.contains("")) + assert(userSection.contains("John Doe")) + assert(userSection.contains("john@example.com")) + } + + test("extractSection should return None for non-existent section") { + val xml = "success" + val result = SimpleXmlParser.extractSection("user", xml) + assertEquals(result, None) + } + + test("extractSection should return None for malformed XML section") { + val xml = "John Doe" + val result = SimpleXmlParser.extractSection("user", xml) + assertEquals(result, None) + } + + test("extractSection should handle nested sections") { + val xml = 
""" + + + test + + + """ + val result = SimpleXmlParser.extractSection("inner", xml) + assert(result.isDefined) + val innerSection = result.get + assertEquals(innerSection.trim, "\n test\n ") + } + + test("requireTag should return content for existing tag") { + val xml = "AKIAIOSFODNN7EXAMPLE" + val result = SimpleXmlParser.requireTag("AccessKeyId", xml, "AccessKeyId not found") + assertEquals(result, "AKIAIOSFODNN7EXAMPLE") + } + + test("requireTag should throw exception for non-existent tag") { + val xml = "John Doe" + intercept[IllegalArgumentException] { + SimpleXmlParser.requireTag("email", xml, "Email tag not found") + } + } + + test("requireTag should throw exception for empty tag content") { + val xml = "" + intercept[IllegalArgumentException] { + SimpleXmlParser.requireTag("name", xml, "Name cannot be empty") + } + } + + test("requireTag should throw exception for whitespace-only content") { + val xml = " " + intercept[IllegalArgumentException] { + SimpleXmlParser.requireTag("name", xml, "Name cannot be empty") + } + } + + test("requireTag should handle valid content with entities") { + val xml = "Hello & welcome" + val result = SimpleXmlParser.requireTag("message", xml, "Message not found") + assertEquals(result, "Hello & welcome") + } + + test("parse AWS STS AssumeRoleWithWebIdentity response") { + val stsResponse = """ + + + + ASIAIOSFODNN7EXAMPLE + wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY + IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3 + 2024-12-06T12:00:00Z + + amzn1.account.AF6RHO7KZU5XRVQJGXK6HB56KR2A + + AROA3XFRBF535PLBQAARL:app-session-123 + arn:aws:sts::123456789012:assumed-role/TestRole/app-session-123 + + + + c6104cbe-af31-11e0-8154-cbc7ccf896c7 + + """ + + // Extract credentials section + val credentialsSection = SimpleXmlParser.extractSection("Credentials", stsResponse) + assert(credentialsSection.isDefined) + + // Extract individual credential fields + val credentials = credentialsSection.get + val accessKeyId = 
SimpleXmlParser.extractTagContent("AccessKeyId", credentials) + val secretAccessKey = SimpleXmlParser.extractTagContent("SecretAccessKey", credentials) + val sessionToken = SimpleXmlParser.extractTagContent("SessionToken", credentials) + val expiration = SimpleXmlParser.extractTagContent("Expiration", credentials) + + assertEquals(accessKeyId, Some("ASIAIOSFODNN7EXAMPLE")) + assertEquals(secretAccessKey, Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY")) + assertEquals(sessionToken, Some("IQoJb3JpZ2luX2VjECoaCXVzLWVhc3QtMSJHMEUCIQDtqstfDEaRfZKFK5Z2n2CnP3")) + assertEquals(expiration, Some("2024-12-06T12:00:00Z")) + + // Test requireTag functionality + val requiredAccessKeyId = SimpleXmlParser.requireTag("AccessKeyId", credentials, "AccessKeyId is required") + assertEquals(requiredAccessKeyId, "ASIAIOSFODNN7EXAMPLE") + } + + test("parse AWS STS error response") { + val errorResponse = """ + + + Sender + InvalidParameterValue + The security token included in the request is invalid + + c6104cbe-af31-11e0-8154-cbc7ccf896c7 + """ + + val errorSection = SimpleXmlParser.extractSection("Error", errorResponse) + assert(errorSection.isDefined) + + val error = errorSection.get + val errorType = SimpleXmlParser.extractTagContent("Type", error) + val errorCode = SimpleXmlParser.extractTagContent("Code", error) + val errorMessage = SimpleXmlParser.extractTagContent("Message", error) + + assertEquals(errorType, Some("Sender")) + assertEquals(errorCode, Some("InvalidParameterValue")) + assertEquals(errorMessage, Some("The security token included in the request is invalid")) + } + + test("handle XML with special characters and entities") { + val xml = """Data contains <brackets> & "quotes" 'apostrophes'""" + val result = SimpleXmlParser.extractTagContent("message", xml) + assertEquals(result, Some("Data contains & \"quotes\" 'apostrophes'")) + } + + test("handle empty XML documents") { + val xml = "" + val result = SimpleXmlParser.extractTagContent("any", xml) + 
assertEquals(result, None) + } + + test("handle XML with CDATA sections") { + val xml = " and & entities]]>" + // Note: This simple parser doesn't handle CDATA, but should still extract content + val result = SimpleXmlParser.extractTagContent("data", xml) + assertEquals(result, Some(" and & entities]]>")) + } diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala index f3eb6089f..eb355d5a7 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/Connection.scala @@ -29,6 +29,8 @@ import ldbc.connector.exception.* import ldbc.connector.net.* import ldbc.connector.net.protocol.* +import ldbc.authentication.plugin.* + type Connection[F[_]] = ldbc.sql.Connection[F] object Connection: @@ -63,19 +65,22 @@ object Connection: this.default[F, Unit](host, port, user, before = unitBefore, after = unitAfter) def apply[F[_]: Async: Network: Console: Hashing: UUIDGen]( - host: String, - port: Int, - user: String, - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - ssl: SSL = SSL.None, - socketOptions: List[SocketOption] = defaultSocketOptions, - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG) + host: String, + port: Int, + user: String, + password: Option[String] = None, + database: Option[String] = None, + debug: Boolean = false, + ssl: SSL = SSL.None, + socketOptions: List[SocketOption] = defaultSocketOptions, + readTimeout: Duration = Duration.Inf, + allowPublicKeyRetrieval: Boolean = false, + useCursorFetch: Boolean = false, + useServerPrepStmts: Boolean = false, + maxAllowedPacket: Int = MySQLConfig.DEFAULT_PACKET_SIZE, 
+ databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = this.default[F, Unit]( host, port, @@ -89,27 +94,33 @@ object Connection: allowPublicKeyRetrieval, useCursorFetch, useServerPrepStmts, + maxAllowedPacket, databaseTerm, + defaultAuthenticationPlugin, + plugins, unitBefore, unitAfter ) def withBeforeAfter[F[_]: Async: Network: Console: Hashing: UUIDGen, A]( - host: String, - port: Int, - user: String, - before: Connection[F] => F[A], - after: (A, Connection[F]) => F[Unit], - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - ssl: SSL = SSL.None, - socketOptions: List[SocketOption] = defaultSocketOptions, - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG) + host: String, + port: Int, + user: String, + before: Connection[F] => F[A], + after: (A, Connection[F]) => F[Unit], + password: Option[String] = None, + database: Option[String] = None, + debug: Boolean = false, + ssl: SSL = SSL.None, + socketOptions: List[SocketOption] = defaultSocketOptions, + readTimeout: Duration = Duration.Inf, + allowPublicKeyRetrieval: Boolean = false, + useCursorFetch: Boolean = false, + useServerPrepStmts: Boolean = false, + maxAllowedPacket: Int = MySQLConfig.DEFAULT_PACKET_SIZE, + databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = this.default( host, 
port, @@ -123,27 +134,33 @@ object Connection: allowPublicKeyRetrieval, useCursorFetch, useServerPrepStmts, + maxAllowedPacket, databaseTerm, + defaultAuthenticationPlugin, + plugins, before, after ) def default[F[_]: Async: Network: Console: Hashing: UUIDGen, A]( - host: String, - port: Int, - user: String, - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - ssl: SSL = SSL.None, - socketOptions: List[SocketOption] = defaultSocketOptions, - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), - before: Connection[F] => F[A], - after: (A, Connection[F]) => F[Unit] + host: String, + port: Int, + user: String, + password: Option[String] = None, + database: Option[String] = None, + debug: Boolean = false, + ssl: SSL = SSL.None, + socketOptions: List[SocketOption] = defaultSocketOptions, + readTimeout: Duration = Duration.Inf, + allowPublicKeyRetrieval: Boolean = false, + useCursorFetch: Boolean = false, + useServerPrepStmts: Boolean = false, + maxAllowedPacket: Int = MySQLConfig.DEFAULT_PACKET_SIZE, + databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], + before: Connection[F] => F[A], + after: (A, Connection[F]) => F[Unit] ): Tracer[F] ?=> Resource[F, LdbcConnection[F]] = val logger: String => F[Unit] = s => Console[F].println(s"TLS: $s") @@ -164,29 +181,36 @@ object Connection: allowPublicKeyRetrieval, useCursorFetch, useServerPrepStmts, + maxAllowedPacket, databaseTerm, + defaultAuthenticationPlugin, + plugins, before, after ) yield connection def fromSockets[F[_]: Async: Tracer: Console: Hashing: UUIDGen, A]( - sockets: 
Resource[F, Socket[F]], - host: String, - port: Int, - user: String, - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - sslOptions: Option[SSLNegotiation.Options[F]], - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, - acquire: Connection[F] => F[A], - release: (A, Connection[F]) => F[Unit] + sockets: Resource[F, Socket[F]], + host: String, + port: Int, + user: String, + password: Option[String] = None, + database: Option[String] = None, + debug: Boolean = false, + sslOptions: Option[SSLNegotiation.Options[F]], + readTimeout: Duration = Duration.Inf, + allowPublicKeyRetrieval: Boolean = false, + useCursorFetch: Boolean = false, + useServerPrepStmts: Boolean = false, + maxAllowedPacket: Int = MySQLConfig.DEFAULT_PACKET_SIZE, + databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]], + acquire: Connection[F] => F[A], + release: (A, Connection[F]) => F[Unit] ): Resource[F, LdbcConnection[F]] = + val pluginMap = plugins.map(plugin => plugin.name.toString -> plugin).toMap val capabilityFlags = defaultCapabilityFlags ++ (if database.isDefined then Set(CapabilitiesFlags.CLIENT_CONNECT_WITH_DB) else Set.empty) ++ (if sslOptions.isDefined then Set(CapabilitiesFlags.CLIENT_SSL) else Set.empty) @@ -194,7 +218,18 @@ object Connection: for given Exchange[F] <- Resource.eval(Exchange[F]) protocol <- - Protocol[F](sockets, hostInfo, debug, sslOptions, allowPublicKeyRetrieval, readTimeout, capabilityFlags) + Protocol[F]( + sockets, + hostInfo, + debug, + sslOptions, + allowPublicKeyRetrieval, + readTimeout, + capabilityFlags, + maxAllowedPacket, + defaultAuthenticationPlugin, + pluginMap + ) _ <- Resource.eval(protocol.startAuthentication(user, 
password.getOrElse(""))) serverVariables <- Resource.eval(protocol.serverVariables()) readOnly <- Resource.eval(Ref[F].of[Boolean](false)) @@ -220,22 +255,25 @@ object Connection: yield connection def fromSocketGroup[F[_]: Tracer: Console: Hashing: UUIDGen, A]( - socketGroup: SocketGroup[F], - host: String, - port: Int, - user: String, - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - socketOptions: List[SocketOption], - sslOptions: Option[SSLNegotiation.Options[F]], - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, - acquire: Connection[F] => F[A], - release: (A, Connection[F]) => F[Unit] + socketGroup: SocketGroup[F], + host: String, + port: Int, + user: String, + password: Option[String] = None, + database: Option[String] = None, + debug: Boolean = false, + socketOptions: List[SocketOption], + sslOptions: Option[SSLNegotiation.Options[F]], + readTimeout: Duration = Duration.Inf, + allowPublicKeyRetrieval: Boolean = false, + useCursorFetch: Boolean = false, + useServerPrepStmts: Boolean = false, + maxAllowedPacket: Int = MySQLConfig.DEFAULT_PACKET_SIZE, + databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = None, + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], + plugins: List[AuthenticationPlugin[F]], + acquire: Connection[F] => F[A], + release: (A, Connection[F]) => F[Unit] )(using ev: Async[F]): Resource[F, LdbcConnection[F]] = def fail[B](msg: String): Resource[F, B] = @@ -261,7 +299,10 @@ object Connection: allowPublicKeyRetrieval, useCursorFetch, useServerPrepStmts, + maxAllowedPacket, databaseTerm, + defaultAuthenticationPlugin, + plugins, acquire, release ) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/ConnectionProvider.scala 
b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/ConnectionProvider.scala deleted file mode 100644 index 8208b597d..000000000 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/ConnectionProvider.scala +++ /dev/null @@ -1,522 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.connector - -import scala.concurrent.duration.Duration - -import cats.effect.* -import cats.effect.std.Console -import cats.effect.std.UUIDGen - -import fs2.hashing.Hashing -import fs2.io.net.* - -import org.typelevel.otel4s.trace.Tracer - -import ldbc.sql.DatabaseMetaData - -import ldbc.{ Connector, Provider } -import ldbc.logging.LogHandler - -@deprecated( - "Connection creation using ConnectionProvider is now deprecated. Please use ldbc.connector.MySQLDataSource from now on. ConnectionProvider will be removed in version 0.5.x.", - "ldbc 0.4.0" -) -trait ConnectionProvider[F[_], A] extends Provider[F]: - - /** - * Update the host information of the database to be connected. - * - * {{{ - * ConnectionProvider - * ... - * .setHost("127.0.0.1") - * }}} - * - * @param host - * Host information of the database to be connected - */ - def setHost(host: String): ConnectionProvider[F, A] - - /** - * Update the port information of the database to be connected. - * - * {{{ - * ConnectionProvider - * ... - * .setPort(3306) - * }}} - * - * @param port - * Port information of the database to be connected - */ - def setPort(port: Int): ConnectionProvider[F, A] - - /** - * Update the user information of the database to be connected. - * - * {{{ - * ConnectionProvider - * ... - * .setUser("root") - * }}} - * - * @param user - * User information of the database to be connected - */ - def setUser(user: String): ConnectionProvider[F, A] - - /** - * Update the password information of the database to be connected. 
- * - * {{{ - * ConnectionProvider - * ... - * .setPassword("password") - * }}} - * - * @param password - * Password information of the database to be connected - */ - def setPassword(password: String): ConnectionProvider[F, A] - - /** - * Update the database to be connected. - * - * {{{ - * ConnectionProvider - * ... - * .setDatabase("database") - * }}} - * - * @param database - * Database name to connect to - */ - def setDatabase(database: String): ConnectionProvider[F, A] - - /** - * Update the setting of whether or not to output the log of packet communications for connection processing. - * - * Default is false. - * - * {{{ - * ConnectionProvider - * ... - * .setDebug(true) - * }}} - * - * @param debug - * Whether packet communication logs are output or not - */ - def setDebug(debug: Boolean): ConnectionProvider[F, A] - - /** - * Update whether SSL communication is used. - * - * {{{ - * ConnectionProvider - * ... - * .setSSL(SSL.Trusted) - * }}} - * - * @param ssl - * SSL set value. Changes the way certificates are operated, etc. - */ - def setSSL(ssl: SSL): ConnectionProvider[F, A] - - /** - * Update socket options for TCP/UDP sockets. - * - * {{{ - * ConnectionProvider - * ... - * .addSocketOption(SocketOption.noDelay(true)) - * }}} - * - * @param socketOption - * Socket options for TCP/UDP sockets - */ - def addSocketOption(socketOption: SocketOption): ConnectionProvider[F, A] - - /** - * Update socket options for TCP/UDP sockets. - * - * {{{ - * ConnectionProvider - * ... - * .setSocketOptions(List(SocketOption.noDelay(true))) - * }}} - * - * @param socketOptions - * List of socket options for TCP/UDP sockets - */ - def setSocketOptions(socketOptions: List[SocketOption]): ConnectionProvider[F, A] - - /** - * Update the read timeout value. - * - * {{{ - * ConnectionProvider - * ... 
- * .setReadTimeout(Duration.Inf) - * }}} - * - * @param readTimeout - * Read timeout value - */ - def setReadTimeout(readTimeout: Duration): ConnectionProvider[F, A] - - /** - * Update the setting of whether or not to replace the public key. - * - * {{{ - * ConnectionProvider - * ... - * .setAllowPublicKeyRetrieval(true) - * }}} - * - * @param allowPublicKeyRetrieval - * Whether to replace the public key - */ - def setAllowPublicKeyRetrieval(allowPublicKeyRetrieval: Boolean): ConnectionProvider[F, A] - - /** - * Update whether the JDBC term “catalog” or “schema” is used to refer to the database in the application. - * - * {{{ - * ConnectionProvider - * ... - * .setDatabaseTerm(DatabaseMetaData.DatabaseTerm.SCHEMA) - * }}} - * - * @param databaseTerm - * The JDBC terms [[DatabaseMetaData.DatabaseTerm.CATALOG]] and [[DatabaseMetaData.DatabaseTerm.SCHEMA]] are used to refer to the database. - */ - def setDatabaseTerm(databaseTerm: DatabaseMetaData.DatabaseTerm): ConnectionProvider[F, A] - - /** - * Update handler to output execution log of processes using connections. - * - * {{{ - * ConnectionProvider - * ... - * .setLogHandler(consoleLogger) - * }}} - * - * @param handler - * Handler for outputting logs of process execution using connections. - */ - def setLogHandler(handler: LogHandler[F]): ConnectionProvider[F, A] - - /** - * Update tracers to output metrics. - * - * {{{ - * ConnectionProvider - * ... - * .setTracer(Tracer.noop[IO]) - * }}} - * - * @param tracer - * Tracer to output metrics - */ - def setTracer(tracer: Tracer[F]): ConnectionProvider[F, A] - - /** - * Update whether to use cursor fetch for large result sets. - * - * {{{ - * ConnectionProvider - * ... - * .setUseCursorFetch(true) - * }}} - * - * @param useCursorFetch - * Whether to use cursor fetch for large result sets. - */ - def setUseCursorFetch(useCursorFetch: Boolean): ConnectionProvider[F, A] - - /** - * Update whether to use server prepared statements. 
- * - * {{{ - * ConnectionProvider - * ... - * .setUseServerPrepStmts(true) - * }}} - * - * @param useServerPrepStmts - * Whether to use server prepared statements. - */ - def setUseServerPrepStmts(useServerPrepStmts: Boolean): ConnectionProvider[F, A] - - /** - * Add an optional process to be executed immediately after connection is established. - * - * {{{ - * val before = ??? - * - * ConnectionProvider - * ... - * .withBefore(before) - * }}} - * - * @param before - * Arbitrary processing to be performed immediately after connection is established - * @tparam B - * Value returned after the process is executed. This value can be passed to the After process. - */ - def withBefore[B](before: Connection[F] => F[B]): ConnectionProvider[F, B] - - /** - * Add any process to be performed before disconnecting the connection. - * - * {{{ - * val after = ??? - * - * ConnectionProvider - * ... - * .withAfter(after) - * }}} - * - * @param after - * Arbitrary processing to be performed before disconnecting - */ - def withAfter(after: (A, Connection[F]) => F[Unit]): ConnectionProvider[F, A] - - /** - * Add optional processing to be performed immediately after establishing a connection and before disconnecting. - * - * The order of processing is as follows. - * - * {{{ - * 1. connection establishment - * 2. before operation - * 3. Processing using connections. Any processing used primarily in the operation of an application. - * 4. after operation - * 5. Disconnection - * }}} - * - * {{{ - * val before = ??? - * val after = ??? - * - * ConnectionProvider - * ... - * .withBeforeAfter(before, after) - * }}} - * - * @param before - * Arbitrary processing to be performed immediately after connection is established - * @param after - * Arbitrary processing to be performed before disconnecting - * @tparam B - * Value returned after the process is executed. This value can be passed to the After process. 
- */ - def withBeforeAfter[B](before: Connection[F] => F[B], after: (B, Connection[F]) => F[Unit]): ConnectionProvider[F, B] - -object ConnectionProvider: - - val defaultSocketOptions: List[SocketOption] = - List(SocketOption.noDelay(true)) - - @annotation.nowarn - private case class Impl[F[_]: Async: Network: Console: Hashing: UUIDGen, A]( - host: String, - port: Int, - user: String, - logHandler: Option[LogHandler[F]] = None, - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - ssl: SSL = SSL.None, - socketOptions: List[SocketOption] = defaultSocketOptions, - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), - tracer: Option[Tracer[F]] = None, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - before: Option[Connection[F] => F[A]] = None, - after: Option[(A, Connection[F]) => F[Unit]] = None - ) extends ConnectionProvider[F, A]: - given Tracer[F] = tracer.getOrElse(Tracer.noop[F]) - - override def setHost(host: String): ConnectionProvider[F, A] = - this.copy(host = host) - - override def setPort(port: Int): ConnectionProvider[F, A] = - this.copy(port = port) - - override def setUser(user: String): ConnectionProvider[F, A] = - this.copy(user = user) - - override def setPassword(password: String): ConnectionProvider[F, A] = - this.copy(password = Some(password)) - - override def setDatabase(database: String): ConnectionProvider[F, A] = - this.copy(database = Some(database)) - - override def setDebug(debug: Boolean): ConnectionProvider[F, A] = - this.copy(debug = debug) - - override def setSSL(ssl: SSL): ConnectionProvider[F, A] = - this.copy(ssl = ssl) - - override def addSocketOption(socketOption: SocketOption): ConnectionProvider[F, A] = - this.copy(socketOptions = socketOptions.::(socketOption)) - - override def setSocketOptions(socketOptions: 
List[SocketOption]): ConnectionProvider[F, A] = - this.copy(socketOptions = socketOptions) - - override def setReadTimeout(readTimeout: Duration): ConnectionProvider[F, A] = - this.copy(readTimeout = readTimeout) - - override def setAllowPublicKeyRetrieval(allowPublicKeyRetrieval: Boolean): ConnectionProvider[F, A] = - this.copy(allowPublicKeyRetrieval = allowPublicKeyRetrieval) - - override def setDatabaseTerm(databaseTerm: DatabaseMetaData.DatabaseTerm): ConnectionProvider[F, A] = - this.copy(databaseTerm = Some(databaseTerm)) - - override def setLogHandler(handler: LogHandler[F]): ConnectionProvider[F, A] = - this.copy(logHandler = Some(handler)) - - override def setTracer(tracer: Tracer[F]): ConnectionProvider[F, A] = - this.copy(tracer = Some(tracer)) - - override def setUseCursorFetch(useCursorFetch: Boolean): ConnectionProvider[F, A] = - val useServerPrepStmts = if useCursorFetch then true else this.useServerPrepStmts - this.copy(useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts) - - override def setUseServerPrepStmts(useServerPrepStmts: Boolean): ConnectionProvider[F, A] = - this.copy(useServerPrepStmts = useServerPrepStmts) - - override def withBefore[B](before: Connection[F] => F[B]): ConnectionProvider[F, B] = - Impl( - host = host, - port = port, - user = user, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - databaseTerm = databaseTerm, - logHandler = logHandler, - tracer = tracer, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - before = Some(before), - after = None - ) - override def withAfter(after: (A, Connection[F]) => F[Unit]): ConnectionProvider[F, A] = - this.copy(after = Some(after)) - - override def withBeforeAfter[B]( - before: Connection[F] => F[B], - after: (B, Connection[F]) => F[Unit] - ): ConnectionProvider[F, B] = - Impl( - host = host, - 
port = port, - user = user, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - databaseTerm = databaseTerm, - tracer = tracer, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - before = Some(before), - after = Some(after) - ) - - override def createConnection(): Resource[F, Connection[F]] = - (before, after) match - case (Some(b), Some(a)) => - Connection.withBeforeAfter( - host = host, - port = port, - user = user, - before = b, - after = a, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm - ) - case (Some(b), None) => - Connection.withBeforeAfter( - host = host, - port = port, - user = user, - before = b, - after = (_, _) => Async[F].unit, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm - ) - case (None, _) => - Connection( - host = host, - port = port, - user = user, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm - ) - - override def use[B](f: Connector[F] => F[B]): F[B] = - createConnector().use(f) - - override def createConnector(): Resource[F, Connector[F]] = - createConnection().map(conn => Connector.fromConnection(conn, logHandler)) - - 
@annotation.nowarn - def default[F[_]: Async: Network: Console: Hashing: UUIDGen]( - host: String, - port: Int, - user: String - ): ConnectionProvider[F, Unit] = Impl[F, Unit](host, port, user) - - @annotation.nowarn - def default[F[_]: Async: Network: Console: Hashing: UUIDGen]( - host: String, - port: Int, - user: String, - password: String, - database: String - ): ConnectionProvider[F, Unit] = - default[F](host, port, user) - .setPassword(password) - .setDatabase(database) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala index 9cab703ae..486c2dd18 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLConfig.scala @@ -416,11 +416,63 @@ trait MySQLConfig: */ def setPoolName(name: String): MySQLConfig + /** + * Gets the maximum allowed packet size for network communication with MySQL server. + * + * This setting controls the maximum size of packets that can be sent to or received from + * the MySQL server. It helps prevent memory exhaustion attacks and ensures compatibility + * with the MySQL protocol limits. + * + * The value corresponds to the MySQL server's `max_allowed_packet` system variable. + * + * @return the maximum packet size in bytes + */ + def maxAllowedPacket: Int + + /** + * Sets the maximum allowed packet size for network communication. 
+ * + * This setting provides protection against: + * - Memory exhaustion attacks through oversized packets + * - Denial of Service (DoS) attacks via large data payloads + * - Accidental transmission of extremely large data sets + * + * @param maxAllowedPacket the maximum packet size in bytes + * @return a new MySQLConfig with the updated setting + * @throws IllegalArgumentException if the value is outside the valid range (1024 to 16,777,215) + * + * @example {{{ + * // Set conservative 64KB limit (default) + * config.setMaxAllowedPacket(65535) + * + * // Set practical 1MB limit for applications with moderate BLOB usage + * config.setMaxAllowedPacket(1048576) + * + * // Set maximum protocol limit for applications requiring large data transfers + * config.setMaxAllowedPacket(16777215) + * }}} + * + * @note The default value of 65,535 bytes (64KB) is compatible with MySQL JDBC Driver defaults + * and provides good security against packet-based attacks while accommodating most use cases. + * @note Valid range: 1,024 bytes (1KB) minimum to 16,777,215 bytes (16MB) maximum (MySQL protocol limit) + * @see [[https://dev.mysql.com/doc/refman/en/packet-too-large.html MySQL Protocol Packet Limits]] + */ + def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLConfig + /** * Companion object for MySQLConfig providing factory methods. */ object MySQLConfig: + /** Minimum allowed packet size in bytes (1KB) */ + val MIN_PACKET_SIZE: Int = 1024 + + /** Maximum allowed packet size in bytes (16MB - MySQL protocol limit) */ + val MAX_PACKET_SIZE: Int = 16777215 + + /** Default packet size in bytes (64KB - MySQL JDBC Driver compatible) */ + val DEFAULT_PACKET_SIZE: Int = 65535 + /** Default socket options applied to all connections. 
*/ private[ldbc] val defaultSocketOptions: List[SocketOption] = List(SocketOption.noDelay(true)) @@ -455,7 +507,8 @@ object MySQLConfig: connectionTestQuery: Option[String] = None, logPoolState: Boolean = false, poolStateLogInterval: FiniteDuration = 30.seconds, - poolName: String = "ldbc-pool" + poolName: String = "ldbc-pool", + maxAllowedPacket: Int = DEFAULT_PACKET_SIZE ) extends MySQLConfig: override def setHost(host: String): MySQLConfig = copy(host = host) @@ -491,6 +544,17 @@ object MySQLConfig: override def setLogPoolState(enabled: Boolean): MySQLConfig = copy(logPoolState = enabled) override def setPoolStateLogInterval(interval: FiniteDuration): MySQLConfig = copy(poolStateLogInterval = interval) override def setPoolName(name: String): MySQLConfig = copy(poolName = name) + override def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLConfig = { + require( + maxAllowedPacket >= MIN_PACKET_SIZE, + s"maxAllowedPacket must be at least $MIN_PACKET_SIZE bytes, but got $maxAllowedPacket" + ) + require( + maxAllowedPacket <= MAX_PACKET_SIZE, + s"maxAllowedPacket must not exceed $MAX_PACKET_SIZE bytes (MySQL protocol limit), but got $maxAllowedPacket" + ) + copy(maxAllowedPacket = maxAllowedPacket) + } /** * Creates a default MySQLConfig with standard connection parameters. 
diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala index c7d41c83f..22599d55c 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/MySQLDataSource.scala @@ -21,6 +21,7 @@ import ldbc.sql.DatabaseMetaData import ldbc.connector.pool.* +import ldbc.authentication.plugin.AuthenticationPlugin import ldbc.DataSource /** @@ -47,6 +48,9 @@ import ldbc.DataSource * @param tracer optional OpenTelemetry tracer for distributed tracing * @param useCursorFetch whether to use cursor-based fetching for result sets * @param useServerPrepStmts whether to use server-side prepared statements + * @param maxAllowedPacket Maximum allowed packet size for network communication in bytes. + * @param defaultAuthenticationPlugin The authentication plugin used first for communication with the server + * @param plugins Additional authentication plugins used for communication with the server * @param before optional hook to execute before a connection is acquired * @param after optional hook to execute after a connection is used * @@ -67,22 +71,25 @@ import ldbc.DataSource * }}} */ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen, A]( - host: String, - port: Int, - user: String, - password: Option[String] = None, - database: Option[String] = None, - debug: Boolean = false, - ssl: SSL = SSL.None, - socketOptions: List[SocketOption] = MySQLConfig.defaultSocketOptions, - readTimeout: Duration = Duration.Inf, - allowPublicKeyRetrieval: Boolean = false, - databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), - tracer: Option[Tracer[F]] = None, - useCursorFetch: Boolean = false, - useServerPrepStmts: Boolean = false, - before: Option[Connection[F] => F[A]] = None, - after: Option[(A, Connection[F]) 
=> F[Unit]] = None + host: String, + port: Int, + user: String, + password: Option[String] = None, + database: Option[String] = None, + debug: Boolean = false, + ssl: SSL = SSL.None, + socketOptions: List[SocketOption] = MySQLConfig.defaultSocketOptions, + readTimeout: Duration = Duration.Inf, + allowPublicKeyRetrieval: Boolean = false, + databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), + tracer: Option[Tracer[F]] = None, + useCursorFetch: Boolean = false, + useServerPrepStmts: Boolean = false, + maxAllowedPacket: Int = MySQLConfig.DEFAULT_PACKET_SIZE, + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], + before: Option[Connection[F] => F[A]] = None, + after: Option[(A, Connection[F]) => F[Unit]] = None ) extends DataSource[F]: given Tracer[F] = tracer.getOrElse(Tracer.noop[F]) @@ -99,55 +106,64 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen (before, after) match case (Some(b), Some(a)) => Connection.withBeforeAfter( - host = host, - port = port, - user = user, - before = b, - after = a, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm + host = host, + port = port, + user = user, + before = b, + after = a, + password = password, + database = database, + debug = debug, + ssl = ssl, + socketOptions = socketOptions, + readTimeout = readTimeout, + allowPublicKeyRetrieval = allowPublicKeyRetrieval, + useCursorFetch = useCursorFetch, + useServerPrepStmts = useServerPrepStmts, + maxAllowedPacket = maxAllowedPacket, + databaseTerm = databaseTerm, + defaultAuthenticationPlugin = defaultAuthenticationPlugin, + plugins = plugins ) case (Some(b), 
None) => Connection.withBeforeAfter( - host = host, - port = port, - user = user, - before = b, - after = (_, _) => Async[F].unit, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm + host = host, + port = port, + user = user, + before = b, + after = (_, _) => Async[F].unit, + password = password, + database = database, + debug = debug, + ssl = ssl, + socketOptions = socketOptions, + readTimeout = readTimeout, + allowPublicKeyRetrieval = allowPublicKeyRetrieval, + useCursorFetch = useCursorFetch, + useServerPrepStmts = useServerPrepStmts, + maxAllowedPacket = maxAllowedPacket, + databaseTerm = databaseTerm, + defaultAuthenticationPlugin = defaultAuthenticationPlugin, + plugins = plugins ) case (None, _) => Connection( - host = host, - port = port, - user = user, - password = password, - database = database, - debug = debug, - ssl = ssl, - socketOptions = socketOptions, - readTimeout = readTimeout, - allowPublicKeyRetrieval = allowPublicKeyRetrieval, - useCursorFetch = useCursorFetch, - useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm + host = host, + port = port, + user = user, + password = password, + database = database, + debug = debug, + ssl = ssl, + socketOptions = socketOptions, + readTimeout = readTimeout, + allowPublicKeyRetrieval = allowPublicKeyRetrieval, + useCursorFetch = useCursorFetch, + useServerPrepStmts = useServerPrepStmts, + maxAllowedPacket = maxAllowedPacket, + databaseTerm = databaseTerm, + defaultAuthenticationPlugin = defaultAuthenticationPlugin, + plugins = plugins ) /** Sets the hostname or IP address of the MySQL server. 
@@ -245,6 +261,44 @@ final case class MySQLDataSource[F[_]: Async: Network: Console: Hashing: UUIDGen def setUseServerPrepStmts(newUseServerPrepStmts: Boolean): MySQLDataSource[F, A] = copy(useServerPrepStmts = newUseServerPrepStmts) + /** Sets the maximum allowed packet size for network communication. + * + * @param maxAllowedPacket the maximum packet size in bytes (1,024 to 16,777,215) + * @return a new MySQLDataSource with the updated packet size limit + * @throws IllegalArgumentException if the value is outside the valid range + */ + def setMaxAllowedPacket(maxAllowedPacket: Int): MySQLDataSource[F, A] = { + require( + maxAllowedPacket >= MySQLConfig.MIN_PACKET_SIZE, + s"maxAllowedPacket must be at least ${ MySQLConfig.MIN_PACKET_SIZE } bytes, but got $maxAllowedPacket" + ) + require( + maxAllowedPacket <= MySQLConfig.MAX_PACKET_SIZE, + s"maxAllowedPacket must not exceed ${ MySQLConfig.MAX_PACKET_SIZE } bytes (MySQL protocol limit), but got $maxAllowedPacket" + ) + copy(maxAllowedPacket = maxAllowedPacket) + } + + /** Sets the authentication plugin to be used first for communication with the server. + * @param defaultAuthenticationPlugin + * The authentication plugin used first for communication with the server + * @return a new MySQLDataSource with the updated setting + */ + def setDefaultAuthenticationPlugin(defaultAuthenticationPlugin: AuthenticationPlugin[F]): MySQLDataSource[F, A] = + copy(defaultAuthenticationPlugin = Some(defaultAuthenticationPlugin)) + + /** + * Sets the authentication plugins to be used for communication with the server.
+ * + * @param p1 + * The authentication plugin used for communication with the server + * @param pn + * List of authentication plugins used for communication with the server + * @return a new MySQLDataSource with the updated setting + */ + def setPlugins(p1: AuthenticationPlugin[F], pn: AuthenticationPlugin[F]*): MySQLDataSource[F, A] = + copy(plugins = p1 :: pn.toList) + /** * Adds a before hook that will be executed when a connection is acquired. * @@ -358,7 +412,8 @@ object MySQLDataSource: allowPublicKeyRetrieval = config.allowPublicKeyRetrieval, databaseTerm = config.databaseTerm, useCursorFetch = config.useCursorFetch, - useServerPrepStmts = config.useServerPrepStmts + useServerPrepStmts = config.useServerPrepStmts, + maxAllowedPacket = config.maxAllowedPacket ) /** @@ -452,8 +507,9 @@ object MySQLDataSource: def pooling[F[_]: Async: Network: Console: Hashing: UUIDGen]( config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]] = None, - tracer: Option[Tracer[F]] = None - ): Resource[F, PooledDataSource[F]] = PooledDataSource.fromConfig(config, metricsTracker, tracer) + tracer: Option[Tracer[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] + ): Resource[F, PooledDataSource[F]] = PooledDataSource.fromConfig(config, metricsTracker, tracer, plugins) /** * Creates a pooled DataSource with connection lifecycle hooks. 
@@ -508,7 +564,8 @@ object MySQLDataSource: config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]] = None, tracer: Option[Tracer[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None ): Resource[F, PooledDataSource[F]] = - PooledDataSource.fromConfigWithBeforeAfter(config, metricsTracker, tracer, before, after) + PooledDataSource.fromConfigWithBeforeAfter(config, metricsTracker, tracer, plugins, before, after) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/AuthenticationPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/AuthenticationPlugin.scala index 8a90704e3..a5d7a3470 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/AuthenticationPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/AuthenticationPlugin.scala @@ -8,8 +8,58 @@ package ldbc.connector.authenticator import scodec.bits.ByteVector +/** + * A trait representing a MySQL authentication plugin for database connections. + * + * This trait defines the contract for various authentication mechanisms supported by MySQL, + * including traditional password-based authentication (mysql_native_password) and + * modern authentication methods like mysql_clear_password for IAM authentication. + * + * Authentication plugins are used during the MySQL handshake process to validate + * client credentials and establish secure database connections. + * + * @tparam F The effect type that wraps the authentication operations + */ +@deprecated("This plugin is deprecated. Please use ldbc.authentication.plugin.AuthenticationPlugin instead.", "0.5.0") trait AuthenticationPlugin[F[_]]: + /** + * The name of the authentication plugin as recognized by the MySQL server. 
+ * + * Common plugin names include: + * - "mysql_native_password" for traditional SHA1-based password authentication + * - "mysql_clear_password" for plaintext password transmission over SSL + * - "caching_sha2_password" for SHA256-based password authentication + * - "mysql_old_password" for legacy MySQL authentication (deprecated) + * + * @return The plugin name string that identifies this authentication method + */ def name: String + /** + * Indicates whether this authentication plugin requires a secure (encrypted) connection. + * + * Some authentication plugins, particularly those that transmit passwords in cleartext + * (like mysql_clear_password), require SSL/TLS encryption to ensure data security. + * Traditional hashing-based plugins may optionally use encryption but don't strictly require it. + * + * @return true if SSL/TLS connection is mandatory for this plugin, false otherwise + */ + def requiresConfidentiality: Boolean + + /** + * Processes the password according to the authentication plugin's requirements. + * + * Different authentication plugins handle passwords differently: + * - mysql_native_password: Performs SHA1-based hashing with the server's scramble + * - mysql_clear_password: Returns the password as plaintext bytes (requires SSL) + * - caching_sha2_password: Performs SHA256-based hashing with salt + * + * @param password The user's password in plaintext + * @param scramble The random challenge bytes sent by the MySQL server during handshake. + * Used as salt/seed for cryptographic hashing in most authentication methods. + * May be ignored by plugins that don't use server-side challenges. 
+ * @return The processed password data wrapped in the effect type F, ready for transmission + * to the MySQL server during authentication + */ def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/CachingSha2PasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/CachingSha2PasswordPlugin.scala index 84b405da8..e81e41c5d 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/CachingSha2PasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/CachingSha2PasswordPlugin.scala @@ -12,9 +12,11 @@ import fs2.hashing.Hashing import ldbc.connector.util.Version +import ldbc.authentication.plugin.* + trait CachingSha2PasswordPlugin[F[_]] extends Sha256PasswordPlugin[F]: - override def name: String = "caching_sha2_password" + override def name: PluginName = CACHING_SHA2_PASSWORD object CachingSha2PasswordPlugin: def apply[F[_]: Hashing: Sync](version: Version): CachingSha2PasswordPlugin[F] = diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala new file mode 100644 index 000000000..50531631b --- /dev/null +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlClearPasswordPlugin.scala @@ -0,0 +1,32 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.connector.authenticator + +import java.nio.charset.StandardCharsets + +import scodec.bits.ByteVector + +import cats.effect.kernel.Sync + +@deprecated( + "This plugin is deprecated. 
Please use ldbc.authentication.plugin.MysqlClearPasswordPlugin instead.", + "0.5.0" +) +class MysqlClearPasswordPlugin[F[_]: Sync] extends AuthenticationPlugin[F]: + + override def name: String = "mysql_clear_password" + override def requiresConfidentiality: Boolean = true + override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = + if password.isEmpty then Sync[F].pure(ByteVector.empty) + else Sync[F].delay(ByteVector(password.getBytes(StandardCharsets.UTF_8))) + +object MysqlClearPasswordPlugin: + @deprecated( + "This plugin is deprecated. Please use ldbc.authentication.plugin.MysqlClearPasswordPlugin instead.", + "0.5.0" + ) + def apply[F[_]: Sync](): MysqlClearPasswordPlugin[F] = new MysqlClearPasswordPlugin[F]() diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala index e695ba6ce..767305b22 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/MysqlNativePasswordPlugin.scala @@ -19,9 +19,12 @@ import fs2.hashing.HashAlgorithm import fs2.hashing.Hashing import fs2.Chunk +import ldbc.authentication.plugin.* + class MysqlNativePasswordPlugin[F[_]: Hashing: Sync] extends AuthenticationPlugin[F]: - override def name: String = "mysql_native_password" + override def name: PluginName = MYSQL_NATIVE_PASSWORD + override def requiresConfidentiality: Boolean = false override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = if password.isEmpty then Sync[F].pure(ByteVector.empty) else diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala index 
784a1a493..6f1df89af 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/authenticator/Sha256PasswordPlugin.scala @@ -18,12 +18,12 @@ import cats.effect.kernel.Sync import fs2.hashing.HashAlgorithm import fs2.hashing.Hashing import fs2.Chunk -trait Sha256PasswordPlugin[F[_]: Hashing: Sync] extends AuthenticationPlugin[F] with Sha256PasswordPluginPlatform[F]: - protected def xorString(from: Array[Byte], scramble: Array[Byte], length: Int): Array[Byte] = - val scrambleLength = scramble.length - (0 until length).map(pos => (from(pos) ^ scramble(pos % scrambleLength)).toByte).toArray - override def name: String = "sha256_password" +import ldbc.authentication.plugin.* + +trait Sha256PasswordPlugin[F[_]: Hashing: Sync] extends AuthenticationPlugin[F], EncryptPasswordPlugin: + override def name: PluginName = SHA256_PASSWORD + override def requiresConfidentiality: Boolean = false override def hashPassword(password: String, scramble: Array[Byte]): F[ByteVector] = if password.isEmpty then Sync[F].pure(ByteVector.empty) else diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Constants.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Constants.scala index 1aee80def..367749c75 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Constants.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Constants.scala @@ -14,4 +14,4 @@ import ldbc.connector.util.Version object Constants: val DRIVER_NAME: String = "MySQL Connector/L" - val DRIVER_VERSION: Version = Version(0, 4, 1) + val DRIVER_VERSION: Version = Version(0, 5, 0) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Parameter.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Parameter.scala index da9b365f8..6805379ed 100644 --- 
a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Parameter.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/data/Parameter.scala @@ -97,7 +97,7 @@ object Parameter: def string(value: String): Parameter = new Parameter: override def columnDataType: ColumnDataType = ColumnDataType.MYSQL_TYPE_STRING - override def sql: String = ("'" + value + "'") + override def sql: String = "'" + value.replaceAll("'", "''").replaceAll("\\\\", "\\\\\\\\") + "'" override def encode: BitVector = val bytes = value.getBytes BitVector(bytes.length) |+| BitVector(copyOf(bytes, bytes.length)) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala new file mode 100644 index 000000000..7360e1c1f --- /dev/null +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/exception/PacketTooBigException.scala @@ -0,0 +1,65 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.connector.exception + +/** + * Exception thrown when a MySQL protocol packet exceeds the configured maximum packet size limit. + * + * This exception serves as a security mechanism to prevent memory exhaustion attacks and ensures + * compatibility with MySQL protocol constraints. It is thrown when the connector attempts to + * send or receive a packet that exceeds the `maxAllowedPacket` configuration setting. 
+ * + * The `maxAllowedPacket` setting corresponds to MySQL's `max_allowed_packet` system variable + * and provides protection against: + * + * - **Memory exhaustion attacks**: Prevents attackers from sending oversized packets to consume server memory + * - **Denial of Service (DoS)**: Blocks attempts to overwhelm the system with large data payloads + * - **Accidental large data transfers**: Catches unintentional transmission of extremely large datasets + * - **Protocol violations**: Enforces MySQL protocol packet size constraints (max 16MB) + * + * @param packetLength the actual size of the packet that exceeded the limit, in bytes + * @param maxAllowed the configured maximum packet size limit, in bytes + * + * @example {{{ + * // Typical usage when packet size validation fails + * try { + * // Some operation that sends a large packet + * connection.executeUpdate(queryWithLargeData) + * } catch { + * case ex: PacketTooBigException => + * println(s"Packet too large: ${ex.getMessage}") + * // Consider increasing maxAllowedPacket or reducing data size + * } + * }}} + * + * @example {{{ + * // Configure larger packet size to handle bigger data + * val config = MySQLConfig.default + * .setMaxAllowedPacket(1048576) // 1MB limit + * + * // Or use MySQL protocol maximum + * val config = MySQLConfig.default + * .setMaxAllowedPacket(16777215) // 16MB - protocol maximum + * }}} + * + * @see [[ldbc.connector.MySQLConfig.maxAllowedPacket]] for configuration details + * @see [[ldbc.connector.MySQLConfig.setMaxAllowedPacket]] for setting the packet size limit + * @see [[https://dev.mysql.com/doc/refman/en/packet-too-large.html MySQL Protocol Packet Limits]] + * + * @note The default packet size limit is 65,535 bytes (64KB), which matches MySQL JDBC Driver defaults + * and provides good security against packet-based attacks while accommodating most use cases. + * + * @note This exception extends SQLException to maintain compatibility with JDBC error handling patterns. 
+ * The error message includes both the actual packet size and the configured limit for debugging. + */ +class PacketTooBigException( + packetLength: Int, + maxAllowed: Int +) extends SQLException( + s"Packet for query is too large ($packetLength > $maxAllowed). " + + s"You can change the value by setting the 'maxAllowedPacket' configuration." + ) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/PacketSocket.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/PacketSocket.scala index 4a4fcc2cc..0714e97be 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/PacketSocket.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/PacketSocket.scala @@ -21,6 +21,7 @@ import fs2.io.net.Socket import fs2.Chunk import ldbc.connector.data.CapabilitiesFlags +import ldbc.connector.exception.PacketTooBigException import ldbc.connector.net.packet.* import ldbc.connector.net.packet.response.InitialPacket import ldbc.connector.net.protocol.parseHeader @@ -41,10 +42,15 @@ trait PacketSocket[F[_]]: object PacketSocket: + val DEFAULT_MAX_PACKET_SIZE = 65535 // 64KB (JDBC Driver default) + val PROTOCOL_MAX_PACKET_SIZE = 16777215 // 16MB (MySQL protocol limit) + val MIN_PACKET_SIZE = 0 + def fromBitVectorSocket[F[_]: Concurrent: Console]( - bvs: BitVectorSocket[F], - debugEnabled: Boolean, - sequenceIdRef: Ref[F, Byte] + bvs: BitVectorSocket[F], + debugEnabled: Boolean, + sequenceIdRef: Ref[F, Byte], + maxAllowedPacket: Int ): PacketSocket[F] = new PacketSocket[F]: private def debug(msg: => String): F[Unit] = @@ -56,6 +62,7 @@ object PacketSocket: (for header <- bvs.read(4) payloadSize = parseHeader(header.toByteArray) + _ <- validatePacketSize(payloadSize) payload <- bvs.read(payloadSize) response = decoder.decodeValue(payload).require _ <- @@ -94,6 +101,17 @@ object PacketSocket: _ <- sequenceIdRef.update(sequenceId => ((sequenceId + 1) % 256).toByte) yield () + private def validatePacketSize(size: 
Int): F[Unit] = + if size < MIN_PACKET_SIZE then + Concurrent[F].raiseError( + PacketTooBigException(size, maxAllowedPacket) + ) + else if size > maxAllowedPacket then + Concurrent[F].raiseError( + PacketTooBigException(size, maxAllowedPacket) + ) + else Concurrent[F].unit + def apply[F[_]: Console: Temporal]( debug: Boolean, sockets: Resource[F, Socket[F]], @@ -101,8 +119,9 @@ object PacketSocket: sequenceIdRef: Ref[F, Byte], initialPacketRef: Ref[F, Option[InitialPacket]], readTimeout: Duration, - capabilitiesFlags: Set[CapabilitiesFlags] + capabilitiesFlags: Set[CapabilitiesFlags], + maxAllowedPacket: Int ): Resource[F, PacketSocket[F]] = BitVectorSocket(sockets, sequenceIdRef, initialPacketRef, sslOptions, readTimeout, capabilitiesFlags).map( - fromBitVectorSocket(_, debug, sequenceIdRef) + fromBitVectorSocket(_, debug, sequenceIdRef, maxAllowedPacket) ) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala index 71141a5eb..052aa59cb 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/Protocol.scala @@ -25,7 +25,7 @@ import fs2.io.net.Socket import org.typelevel.otel4s.trace.{ Span, Tracer } import org.typelevel.otel4s.Attribute -import ldbc.connector.authenticator.* +import ldbc.connector.authenticator.{ CachingSha2PasswordPlugin, MysqlNativePasswordPlugin, Sha256PasswordPlugin } import ldbc.connector.data.* import ldbc.connector.exception.* import ldbc.connector.net.packet.* @@ -34,6 +34,8 @@ import ldbc.connector.net.packet.response.* import ldbc.connector.net.protocol.* import ldbc.connector.telemetry.* +import ldbc.authentication.plugin.* + /** * Protocol is a protocol to communicate with MySQL server. * It provides a way to authenticate, reset sequence id, and close the connection. 
@@ -138,14 +140,16 @@ object Protocol: private val SELECT_SERVER_VARIABLES_QUERY = "SELECT @@session.auto_increment_increment AS auto_increment_increment, @@character_set_client AS character_set_client, @@character_set_connection AS character_set_connection, @@character_set_results AS character_set_results, @@character_set_server AS character_set_server, @@collation_server AS collation_server, @@collation_connection AS collation_connection, @@init_connect AS init_connect, @@interactive_timeout AS interactive_timeout, @@license AS license, @@lower_case_table_names AS lower_case_table_names, @@max_allowed_packet AS max_allowed_packet, @@net_write_timeout AS net_write_timeout, @@performance_schema AS performance_schema, @@sql_mode AS sql_mode, @@system_time_zone AS system_time_zone, @@time_zone AS time_zone, @@transaction_isolation AS transaction_isolation, @@wait_timeout AS wait_timeout" - private[ldbc] case class Impl[F[_]: Async: Tracer: Hashing]( - initialPacket: InitialPacket, - hostInfo: HostInfo, - socket: PacketSocket[F], - useSSL: Boolean = false, - allowPublicKeyRetrieval: Boolean = false, - capabilityFlags: Set[CapabilitiesFlags], - sequenceIdRef: Ref[F, Byte] + private[ldbc] case class Impl[F[_]: Async: Tracer]( + initialPacket: InitialPacket, + hostInfo: HostInfo, + socket: PacketSocket[F], + useSSL: Boolean = false, + allowPublicKeyRetrieval: Boolean = false, + capabilityFlags: Set[CapabilitiesFlags], + sequenceIdRef: Ref[F, Byte], + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], + plugins: Map[String, AuthenticationPlugin[F]] )(using ev: MonadError[F, Throwable], ex: Exchange[F]) extends Protocol[F]: @@ -359,7 +363,7 @@ object Protocol: * Authentication method Switch Request Packet */ private def changeAuthenticationMethod(switchRequestPacket: AuthSwitchRequestPacket, password: String): F[Unit] = - determinatePlugin(switchRequestPacket.pluginName, initialPacket.serverVersion) match + determinatePlugin(switchRequestPacket.pluginName) match 
case Left(error) => ev.raiseError(error) *> socket.send(ComQuitPacket()) case Right(plugin: CachingSha2PasswordPlugin[F]) => for @@ -421,7 +425,7 @@ object Protocol: * Scramble buffer for authentication payload */ private def allowPublicKeyRetrievalRequest( - plugin: Sha256PasswordPlugin[F], + plugin: EncryptPasswordPlugin, password: String, scrambleBuff: Array[Byte] ): F[Unit] = @@ -488,7 +492,7 @@ object Protocol: capabilityFlags, username, hashedPassword.length.toByte +: hashedPassword.toArray, - plugin.name, + plugin.name.toString, initialPacket.characterSet, hostInfo.database ) @@ -503,16 +507,27 @@ object Protocol: Attribute("username", username) ))* ) *> ( - determinatePlugin(initialPacket.authPlugin, initialPacket.serverVersion) match - case Left(error) => span.recordException(error) *> ev.raiseError(error) *> socket.send(ComQuitPacket()) - case Right(plugin) => handshake(plugin, username, password) *> readUntilOk(plugin, password) + defaultAuthenticationPlugin match + case Some(plugin) => + checkRequiresConfidentiality(plugin, span) *> handshake(plugin, username, password) *> readUntilOk( + plugin, + password + ) + case None => + determinatePlugin(initialPacket.authPlugin) match + case Left(error) => span.recordException(error) *> ev.raiseError(error) *> socket.send(ComQuitPacket()) + case Right(plugin) => + checkRequiresConfidentiality(plugin, span) *> handshake(plugin, username, password) *> readUntilOk( + plugin, + password + ) ) } override def changeUser(user: String, password: String): F[Unit] = exchange[F, Unit](TelemetrySpanName.CHANGE_USER) { (span: Span[F]) => span.addAttributes(attributes*) *> ( - determinatePlugin(initialPacket.authPlugin, initialPacket.serverVersion) match + determinatePlugin(initialPacket.authPlugin) match case Left(error) => span.recordException(error) *> ev.raiseError(error) *> socket.send(ComQuitPacket()) case Right(plugin) => for @@ -532,20 +547,60 @@ object Protocol: ) } + private def checkRequiresConfidentiality(plugin: 
AuthenticationPlugin[F], span: Span[F]): F[Unit] = + if plugin.requiresConfidentiality && !useSSL then + val error = new SQLInvalidAuthorizationSpecException( + s"SSL connection required for plugin “${ plugin.name }”. Check if ‘ssl’ is enabled.", + hint = Some( + """// You can enable SSL. + | MySQLDataSource.build[IO](....).setSSL(SSL.Trusted) + |""".stripMargin + ) + ) + span.recordException(error) *> ev.raiseError(error) + else ev.unit + + private def determinatePlugin(pluginName: String): Either[SQLException, AuthenticationPlugin[F]] = + plugins + .get(pluginName) + .toRight( + new SQLInvalidAuthorizationSpecException( + s"Unknown authentication plugin: $pluginName", + detail = Some( + "This error may be due to lack of support on the ldbc side or a newly added plugin on the MySQL side." + ), + hint = Some( + "Report Issues here: https://github.com/takapi327/ldbc/issues/new?assignees=&labels=&projects=&template=feature_request.md&title=" + ) + ) + ) + def apply[F[_]: Async: Console: Tracer: Exchange: Hashing]( - sockets: Resource[F, Socket[F]], - hostInfo: HostInfo, - debug: Boolean, - sslOptions: Option[SSLNegotiation.Options[F]], - allowPublicKeyRetrieval: Boolean = false, - readTimeout: Duration, - capabilitiesFlags: Set[CapabilitiesFlags] + sockets: Resource[F, Socket[F]], + hostInfo: HostInfo, + debug: Boolean, + sslOptions: Option[SSLNegotiation.Options[F]], + allowPublicKeyRetrieval: Boolean = false, + readTimeout: Duration, + capabilitiesFlags: Set[CapabilitiesFlags], + maxAllowedPacket: Int, + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], + plugins: Map[String, AuthenticationPlugin[F]] ): Resource[F, Protocol[F]] = for sequenceIdRef <- Resource.eval(Ref[F].of[Byte](0x01)) initialPacketRef <- Resource.eval(Ref[F].of[Option[InitialPacket]](None)) packetSocket <- - PacketSocket[F](debug, sockets, sslOptions, sequenceIdRef, initialPacketRef, readTimeout, capabilitiesFlags) + PacketSocket[F]( + debug, + sockets, + sslOptions, + sequenceIdRef, 
+ initialPacketRef, + readTimeout, + capabilitiesFlags, + maxAllowedPacket + ) protocol <- Resource.eval( fromPacketSocket( packetSocket, @@ -554,19 +609,23 @@ object Protocol: allowPublicKeyRetrieval, capabilitiesFlags, sequenceIdRef, - initialPacketRef + initialPacketRef, + defaultAuthenticationPlugin, + plugins ) ) yield protocol def fromPacketSocket[F[_]: Tracer: Exchange: Hashing]( - packetSocket: PacketSocket[F], - hostInfo: HostInfo, - sslOptions: Option[SSLNegotiation.Options[F]], - allowPublicKeyRetrieval: Boolean = false, - capabilitiesFlags: Set[CapabilitiesFlags], - sequenceIdRef: Ref[F, Byte], - initialPacketRef: Ref[F, Option[InitialPacket]] + packetSocket: PacketSocket[F], + hostInfo: HostInfo, + sslOptions: Option[SSLNegotiation.Options[F]], + allowPublicKeyRetrieval: Boolean = false, + capabilitiesFlags: Set[CapabilitiesFlags], + sequenceIdRef: Ref[F, Byte], + initialPacketRef: Ref[F, Option[InitialPacket]], + defaultAuthenticationPlugin: Option[AuthenticationPlugin[F]], + plugins: Map[String, AuthenticationPlugin[F]] )(using ev: Async[F]): F[Protocol[F]] = initialPacketRef.get.flatMap { case Some(initialPacket) => @@ -578,7 +637,13 @@ object Protocol: sslOptions.isDefined, allowPublicKeyRetrieval, capabilitiesFlags, - sequenceIdRef + sequenceIdRef, + defaultAuthenticationPlugin, + Map( + MYSQL_NATIVE_PASSWORD.toString -> MysqlNativePasswordPlugin[F](), + SHA256_PASSWORD.toString -> Sha256PasswordPlugin[F](), + CACHING_SHA2_PASSWORD.toString -> CachingSha2PasswordPlugin[F](initialPacket.serverVersion) + ) ++ plugins ) ) case None => diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala index 85cdcb723..104e9dde2 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala +++ 
b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/protocol/Authentication.scala @@ -6,14 +6,6 @@ package ldbc.connector.net.protocol -import cats.effect.kernel.Sync - -import fs2.hashing.Hashing - -import ldbc.connector.authenticator.* -import ldbc.connector.exception.* -import ldbc.connector.util.Version - /** * Protocol to handle the Authentication Phase * @@ -28,33 +20,7 @@ import ldbc.connector.util.Version * @tparam F * The effect type */ -trait Authentication[F[_]: Hashing: Sync]: - - /** - * Determine the authentication plugin. - * - * @param pluginName - * Plugin name - * @param version - * MySQL Server version - */ - protected def determinatePlugin(pluginName: String, version: Version): Either[SQLException, AuthenticationPlugin[F]] = - pluginName match - case "mysql_native_password" => Right(MysqlNativePasswordPlugin[F]()) - case "sha256_password" => Right(Sha256PasswordPlugin[F]()) - case "caching_sha2_password" => Right(CachingSha2PasswordPlugin[F](version)) - case unknown => - Left( - new SQLInvalidAuthorizationSpecException( - s"Unknown authentication plugin: $pluginName", - detail = Some( - "This error may be due to lack of support on the ldbc side or a newly added plugin on the MySQL side." - ), - hint = Some( - "Report Issues here: https://github.com/takapi327/ldbc/issues/new?assignees=&labels=&projects=&template=feature_request.md&title=" - ) - ) - ) +trait Authentication[F[_]]: /** * Start the authentication process. 
diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/package.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/package.scala index 081d78da4..ef58abcf9 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/package.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/package.scala @@ -9,4 +9,3 @@ package ldbc package object connector: export ldbc.Connector export ldbc.DataSource - export ldbc.Provider diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala index 20ca1ad09..1b5806088 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/ConcurrentBag.scala @@ -127,7 +127,6 @@ object ConcurrentBag: /** * Create a new ConcurrentBag instance. * - * @param maxFiberLocalSize maximum number of items to store in fiber-local storage * @tparam F the effect type * @tparam T the type of items stored in the bag * @return a new ConcurrentBag instance @@ -248,36 +247,26 @@ object ConcurrentBag: item.getState.flatMap { state => if state == BagEntry.STATE_REMOVED then Temporal[F].unit else - // Reset state to not in use - item.setState(BagEntry.STATE_NOT_IN_USE) >> - // Check if item is in the shared list - sharedList.get.flatMap { list => - if list.exists(_ eq item) then - // Item is already in the list, handle waiters - waiters.get.flatMap { waiting => - if waiting > 0 then - // Try to hand off directly to a waiter - handoffQueue.tryOffer(item).void - else - // No waiters, item stays in shared list + // Attempt atomic transition from IN_USE to NOT_IN_USE + item.compareAndSet(BagEntry.STATE_IN_USE, BagEntry.STATE_NOT_IN_USE).flatMap { + case true => // Successfully transitioned from IN_USE + continueRequiteProcess(item) + case false => + // Failed transition - could be 
NOT_IN_USE already or other state + item.getState.flatMap { currentState => + currentState match { + case BagEntry.STATE_NOT_IN_USE => + // Item already in correct state - continue with requite + continueRequiteProcess(item) + case BagEntry.STATE_REMOVED => + // Item was removed, nothing to do Temporal[F].unit + case _ => + // Handle other state conflicts + handleStateConflict(item) } - else - // Item not in list, add it with distribution - sharedList.modify { list => - val newList = distributeItem(item, list) - (newList, ()) - } >> - waiters.get.flatMap { waiting => - if waiting > 0 then - // Try to hand off directly to a waiter - handoffQueue.tryOffer(item).flatMap { - case true => Temporal[F].unit - case false => Temporal[F].unit - } - else Temporal[F].unit - } - } + } + } } } @@ -334,3 +323,47 @@ object ConcurrentBag: else val (front, back) = list.splitAt(idx - 1) front ++ (item :: back) + + private def handleStateConflict(item: T): F[Unit] = + item.getState.flatMap { + case BagEntry.STATE_REMOVED => + // Item deleted - No action required (normal termination) + Temporal[F].unit + + case BagEntry.STATE_NOT_IN_USE => + // Other fibers have already completed requite - Avoid duplicate processing + Temporal[F].unit + + case BagEntry.STATE_RESERVED => + // Temporary reservation status - Please wait a short while and try again + Temporal[F].sleep(1.milli) >> + item.compareAndSet(BagEntry.STATE_RESERVED, BagEntry.STATE_NOT_IN_USE).flatMap { + case true => continueRequiteProcess(item) + case false => handleStateConflict(item) + } + + case BagEntry.STATE_IN_USE => + // Anomaly: Still in IN_USE state - Detected state inconsistency + item.setState(BagEntry.STATE_NOT_IN_USE) >> + continueRequiteProcess(item) + + case unknownState => + Temporal[F].raiseError( + new IllegalStateException(s"Unknown bag entry state: $unknownState") + ) + } + + private def continueRequiteProcess(item: T): F[Unit] = + // Continue with the normal requite process after state conflict resolution + 
sharedList.get.flatMap { list => + if list.exists(_ eq item) then + waiters.get.flatMap { waiting => + if waiting > 0 then handoffQueue.tryOffer(item).void + else Temporal[F].unit + } + else + sharedList.modify { list => + val newList = distributeItem(item, list) + (newList, ()) + } + } diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala index c86b31042..c98c8ef99 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/pool/PooledDataSource.scala @@ -26,6 +26,7 @@ import ldbc.sql.DatabaseMetaData import ldbc.connector.* import ldbc.connector.exception.SQLException +import ldbc.authentication.plugin.AuthenticationPlugin import ldbc.DataSource /** @@ -184,6 +185,7 @@ object PooledDataSource: databaseTerm: Option[DatabaseMetaData.DatabaseTerm] = Some(DatabaseMetaData.DatabaseTerm.CATALOG), useCursorFetch: Boolean = false, useServerPrepStmts: Boolean = false, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None, minConnections: Int = 5, @@ -236,7 +238,8 @@ object PooledDataSource: val closeAll = state.connections.traverse_ { pooled => pooled.finalizer.attempt.flatMap { case Left(error) => - poolLogger.debug(s"Error closing connection ${ pooled.id }: ${ error.getMessage }") + poolLogger.debug(s"Error closing connection ${ pooled.id }: ${ error.getMessage }") >> + pooled.connection.close().attempt.void case Right(_) => Temporal[F].unit } @@ -598,7 +601,8 @@ object PooledDataSource: allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm + databaseTerm = databaseTerm, + plugins = plugins ) case (Some(b), None) 
=> Connection.withBeforeAfter( @@ -616,7 +620,8 @@ object PooledDataSource: allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm + databaseTerm = databaseTerm, + plugins = plugins ) case (None, _) => Connection( @@ -632,13 +637,15 @@ object PooledDataSource: allowPublicKeyRetrieval = allowPublicKeyRetrieval, useCursorFetch = useCursorFetch, useServerPrepStmts = useServerPrepStmts, - databaseTerm = databaseTerm + databaseTerm = databaseTerm, + plugins = plugins ) private[connector] def create[F[_]: Async: Network: Console: Hashing: UUIDGen, A]( config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]], idGenerator: F[String], + plugins: List[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None )(using Tracer[F]): Resource[F, PooledDataSource[F]] = @@ -647,7 +654,7 @@ object PooledDataSource: Resource .eval(PoolConfigValidator.validate(config)) .flatMap { _ => - createValidatedPool(config, metricsTracker, idGenerator, before, after) + createValidatedPool(config, metricsTracker, idGenerator, plugins, before, after) } .handleErrorWith { error => Resource.eval(Async[F].raiseError(error)) @@ -657,6 +664,7 @@ object PooledDataSource: config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]], idGenerator: F[String], + plugins: List[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]], after: Option[(A, Connection[F]) => F[Unit]] )(using Tracer[F]): Resource[F, PooledDataSource[F]] = @@ -711,6 +719,7 @@ object PooledDataSource: keepaliveTime = config.keepaliveTime, connectionTestQuery = config.connectionTestQuery, poolLogger = poolLogger, + plugins = plugins, before = before, after = after ) @@ -758,10 +767,11 @@ object PooledDataSource: def fromConfig[F[_]: Async: Network: Console: Hashing: UUIDGen]( config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]] = None, 
- tracer: Option[Tracer[F]] = None + tracer: Option[Tracer[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]] ): Resource[F, PooledDataSource[F]] = given Tracer[F] = tracer.getOrElse(Tracer.noop[F]) - create(config, metricsTracker, UUIDGen[F].randomUUID.map(_.toString)) + create(config, metricsTracker, UUIDGen[F].randomUUID.map(_.toString), plugins) /** * Creates a PooledDataSource with before/after hooks for each connection use. @@ -788,8 +798,9 @@ object PooledDataSource: config: MySQLConfig, metricsTracker: Option[PoolMetricsTracker[F]] = None, tracer: Option[Tracer[F]] = None, + plugins: List[AuthenticationPlugin[F]] = List.empty[AuthenticationPlugin[F]], before: Option[Connection[F] => F[A]] = None, after: Option[(A, Connection[F]) => F[Unit]] = None ): Resource[F, PooledDataSource[F]] = given Tracer[F] = tracer.getOrElse(Tracer.noop[F]) - create(config, metricsTracker, UUIDGen[F].randomUUID.map(_.toString), before, after) + create(config, metricsTracker, UUIDGen[F].randomUUID.map(_.toString), plugins, before, after) diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionProviderTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionProviderTest.scala deleted file mode 100644 index 708276607..000000000 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionProviderTest.scala +++ /dev/null @@ -1,162 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). 
- * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.connector - -import scala.concurrent.duration.Duration - -import cats.effect.IO - -import fs2.io.net.SocketOption - -import org.typelevel.otel4s.trace.Tracer - -import ldbc.sql.DatabaseMetaData - -@annotation.nowarn() -class ConnectionProviderTest extends FTestPlatform: - - /** - * Default connection provider for testing - */ - def defaultProvider: ConnectionProvider[IO, Unit] = - ConnectionProvider - .default[IO]("127.0.0.1", 13306, "ldbc") - - test("ConnectionProvider#setHost") { - val provider = defaultProvider.setHost("localhost") - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setPort") { - val provider = defaultProvider.setPort(3306) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setUser") { - val provider = defaultProvider.setUser("root") - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setPassword") { - val provider = defaultProvider.setPassword("newpassword") - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setDatabase") { - val provider = defaultProvider.setDatabase("testdb") - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setDebug") { - val provider = defaultProvider.setDebug(true) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setSSL") { - val provider = defaultProvider.setSSL(SSL.Trusted) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#addSocketOption") { - val option = SocketOption.noDelay(false) - val provider = defaultProvider.addSocketOption(option) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setSocketOptions") { - val options = List(SocketOption.noDelay(false)) - val provider = defaultProvider.setSocketOptions(options) - assertNotEquals( - provider, 
- defaultProvider - ) - } - - test("ConnectionProvider#setReadTimeout") { - val provider = defaultProvider.setReadTimeout(Duration(5000, "ms")) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setAllowPublicKeyRetrieval") { - val provider = defaultProvider.setAllowPublicKeyRetrieval(true) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setDatabaseTerm") { - val provider = defaultProvider.setDatabaseTerm(DatabaseMetaData.DatabaseTerm.SCHEMA) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#setTracer") { - val tracer = Tracer.noop[IO] - val provider = defaultProvider.setTracer(tracer) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#withBefore") { - val before = (conn: Connection[IO]) => IO.unit - val provider = defaultProvider.withBefore(before) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#withAfter") { - val after = (unit: Unit, conn: Connection[IO]) => IO.unit - val provider = defaultProvider.withAfter(after) - assertNotEquals( - provider, - defaultProvider - ) - } - - test("ConnectionProvider#withBeforeAfter") { - val before = (conn: Connection[IO]) => IO.unit - val after = (unit: Unit, conn: Connection[IO]) => IO.unit - val provider = defaultProvider.withBeforeAfter(before, after) - assertNotEquals( - provider, - defaultProvider - ) - } diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala index 29ef0837b..36ca6df4f 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/ConnectionTest.scala @@ -18,6 +18,8 @@ import ldbc.sql.DatabaseMetaData import ldbc.connector.exception.* +import ldbc.authentication.plugin.MysqlClearPasswordPlugin + class ConnectionTest 
extends FTestPlatform: given Tracer[IO] = Tracer.noop[IO] @@ -267,6 +269,31 @@ class ConnectionTest extends FTestPlatform: assertIOBoolean(connection.use(_ => IO(true))) } + test("You can connect to the database by specifying the default authentication plugin.") { + val connection = Connection[IO]( + host = "127.0.0.1", + port = 13306, + user = "ldbc_mysql_native_user", + password = Some("ldbc_mysql_native_password"), + defaultAuthenticationPlugin = Some(MysqlClearPasswordPlugin[IO]()), + ssl = SSL.Trusted + ) + assertIOBoolean(connection.use(_ => IO(true))) + } + + test( + "Using the MySQL Clear Password Plugin when SSL is not enabled causes an SQLInvalidAuthorizationSpecException to occur." + ) { + val connection = Connection[IO]( + host = "127.0.0.1", + port = 13306, + user = "ldbc_mysql_native_user", + password = Some("ldbc_mysql_native_password"), + defaultAuthenticationPlugin = Some(MysqlClearPasswordPlugin[IO]()) + ) + interceptIO[SQLInvalidAuthorizationSpecException](connection.use(_ => IO(true))) + } + test("Catalog change will change the currently connected Catalog.") { val connection = Connection[IO]( host = "127.0.0.1", @@ -624,7 +651,7 @@ class ConnectionTest extends FTestPlatform: assertIO( connection.use(_.getMetaData().map(_.getDriverVersion())), - "ldbc-connector-0.4.1" + "ldbc-connector-0.5.0" ) } diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLConfigTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLConfigTest.scala index 7717485a4..e84f6a958 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLConfigTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLConfigTest.scala @@ -30,6 +30,7 @@ class MySQLConfigTest extends FTestPlatform: assertEquals(config.databaseTerm, Some(DatabaseMetaData.DatabaseTerm.CATALOG)) assertEquals(config.useCursorFetch, false) assertEquals(config.useServerPrepStmts, false) + assertEquals(config.maxAllowedPacket, 
MySQLConfig.DEFAULT_PACKET_SIZE) } test("setHost should update host value") { @@ -130,6 +131,68 @@ class MySQLConfigTest extends FTestPlatform: assertEquals(updated.useServerPrepStmts, true) } + test("setMaxAllowedPacket should update maxAllowedPacket value") { + val config = MySQLConfig.default + val updated = config.setMaxAllowedPacket(1048576) // 1MB + + assertEquals(updated.maxAllowedPacket, 1048576) + // Ensure other values remain unchanged + assertEquals(updated.host, config.host) + assertEquals(updated.port, config.port) + } + + test("setMaxAllowedPacket should accept minimum valid value") { + val config = MySQLConfig.default + val updated = config.setMaxAllowedPacket(MySQLConfig.MIN_PACKET_SIZE) + + assertEquals(updated.maxAllowedPacket, MySQLConfig.MIN_PACKET_SIZE) + } + + test("setMaxAllowedPacket should accept maximum valid value") { + val config = MySQLConfig.default + val updated = config.setMaxAllowedPacket(MySQLConfig.MAX_PACKET_SIZE) + + assertEquals(updated.maxAllowedPacket, MySQLConfig.MAX_PACKET_SIZE) + } + + test("setMaxAllowedPacket should reject values below minimum") { + val config = MySQLConfig.default + + intercept[IllegalArgumentException] { + config.setMaxAllowedPacket(MySQLConfig.MIN_PACKET_SIZE - 1) + } + } + + test("setMaxAllowedPacket should reject values above maximum") { + val config = MySQLConfig.default + + intercept[IllegalArgumentException] { + config.setMaxAllowedPacket(MySQLConfig.MAX_PACKET_SIZE + 1) + } + } + + test("setMaxAllowedPacket should reject zero value") { + val config = MySQLConfig.default + + intercept[IllegalArgumentException] { + config.setMaxAllowedPacket(0) + } + } + + test("setMaxAllowedPacket should reject negative values") { + val config = MySQLConfig.default + + intercept[IllegalArgumentException] { + config.setMaxAllowedPacket(-1) + } + } + + test("MySQLConfig constants should have expected values") { + assertEquals(MySQLConfig.MIN_PACKET_SIZE, 1024) + assertEquals(MySQLConfig.MAX_PACKET_SIZE, 16777215) + 
assertEquals(MySQLConfig.DEFAULT_PACKET_SIZE, 65535) + } + test("MySQLConfig should be immutable - original config should not change") { val original = MySQLConfig.default val originalHost = original.host @@ -167,6 +230,7 @@ class MySQLConfigTest extends FTestPlatform: .setAllowPublicKeyRetrieval(true) .setUseCursorFetch(true) .setUseServerPrepStmts(true) + .setMaxAllowedPacket(1048576) assertEquals(config.host, "localhost") assertEquals(config.port, 3307) @@ -178,6 +242,7 @@ class MySQLConfigTest extends FTestPlatform: assertEquals(config.allowPublicKeyRetrieval, true) assertEquals(config.useCursorFetch, true) assertEquals(config.useServerPrepStmts, true) + assertEquals(config.maxAllowedPacket, 1048576) } test("MySQLConfig with custom socket options") { diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLDataSourceTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLDataSourceTest.scala index e343c0366..e5f63e101 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLDataSourceTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/MySQLDataSourceTest.scala @@ -40,6 +40,7 @@ class MySQLDataSourceTest extends FTestPlatform: assertEquals(dataSource.databaseTerm, Some(DatabaseMetaData.DatabaseTerm.CATALOG)) assertEquals(dataSource.useCursorFetch, false) assertEquals(dataSource.useServerPrepStmts, false) + assertEquals(dataSource.maxAllowedPacket, MySQLConfig.DEFAULT_PACKET_SIZE) assertEquals(dataSource.before, None) assertEquals(dataSource.after, None) } @@ -149,6 +150,62 @@ class MySQLDataSourceTest extends FTestPlatform: assertEquals(updated.useServerPrepStmts, true) } + test("setMaxAllowedPacket should update maxAllowedPacket value") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + val updated = dataSource.setMaxAllowedPacket(1048576) // 1MB + + assertEquals(updated.maxAllowedPacket, 1048576) + // Ensure other values remain unchanged + 
assertEquals(updated.host, dataSource.host) + assertEquals(updated.port, dataSource.port) + } + + test("setMaxAllowedPacket should accept minimum valid value") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + val updated = dataSource.setMaxAllowedPacket(MySQLConfig.MIN_PACKET_SIZE) + + assertEquals(updated.maxAllowedPacket, MySQLConfig.MIN_PACKET_SIZE) + } + + test("setMaxAllowedPacket should accept maximum valid value") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + val updated = dataSource.setMaxAllowedPacket(MySQLConfig.MAX_PACKET_SIZE) + + assertEquals(updated.maxAllowedPacket, MySQLConfig.MAX_PACKET_SIZE) + } + + test("setMaxAllowedPacket should reject values below minimum") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + + intercept[IllegalArgumentException] { + dataSource.setMaxAllowedPacket(MySQLConfig.MIN_PACKET_SIZE - 1) + } + } + + test("setMaxAllowedPacket should reject values above maximum") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + + intercept[IllegalArgumentException] { + dataSource.setMaxAllowedPacket(MySQLConfig.MAX_PACKET_SIZE + 1) + } + } + + test("setMaxAllowedPacket should reject zero value") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + + intercept[IllegalArgumentException] { + dataSource.setMaxAllowedPacket(0) + } + } + + test("setMaxAllowedPacket should reject negative values") { + val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") + + intercept[IllegalArgumentException] { + dataSource.setMaxAllowedPacket(-1) + } + } + test("withBefore should add a before hook and change type parameter") { val dataSource = MySQLDataSource[IO, Unit]("localhost", 3306, "root") val beforeHook: Connection[IO] => IO[String] = _ => IO.pure("before result") @@ -191,6 +248,7 @@ class MySQLDataSourceTest extends FTestPlatform: .setAllowPublicKeyRetrieval(true) .setUseCursorFetch(true) 
.setUseServerPrepStmts(true) + .setMaxAllowedPacket(1048576) assertEquals(dataSource.host, "127.0.0.1") assertEquals(dataSource.port, 3307) @@ -202,6 +260,7 @@ class MySQLDataSourceTest extends FTestPlatform: assertEquals(dataSource.allowPublicKeyRetrieval, true) assertEquals(dataSource.useCursorFetch, true) assertEquals(dataSource.useServerPrepStmts, true) + assertEquals(dataSource.maxAllowedPacket, 1048576) } test("MySQLDataSource.fromConfig should create DataSource from MySQLConfig") { @@ -212,6 +271,7 @@ class MySQLDataSourceTest extends FTestPlatform: .setPassword("configpass") .setDatabase("configdb") .setDebug(true) + .setMaxAllowedPacket(2097152) // 2MB val dataSource = MySQLDataSource.fromConfig[IO](config) @@ -221,6 +281,7 @@ class MySQLDataSourceTest extends FTestPlatform: assertEquals(dataSource.password, Some("configpass")) assertEquals(dataSource.database, Some("configdb")) assertEquals(dataSource.debug, true) + assertEquals(dataSource.maxAllowedPacket, 2097152) } test("MySQLDataSource.default should create DataSource with default config") { diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/authenticator/CachingSha2PasswordPluginTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/authenticator/CachingSha2PasswordPluginTest.scala index 3396fbada..ab455b982 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/authenticator/CachingSha2PasswordPluginTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/authenticator/CachingSha2PasswordPluginTest.scala @@ -19,7 +19,7 @@ class CachingSha2PasswordPluginTest extends FTestPlatform: test("CachingSha2PasswordPlugin#name should return correct plugin name") { val plugin = CachingSha2PasswordPlugin[IO](Version(8, 0, 10)) - assertEquals(plugin.name, "caching_sha2_password") + assertEquals(plugin.name.toString, "caching_sha2_password") } test("CachingSha2PasswordPlugin for version >= 8.0.5 should use default transformation") { diff 
--git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/data/ParameterTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/data/ParameterTest.scala index f1db74ccb..02fdb3dd7 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/data/ParameterTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/data/ParameterTest.scala @@ -164,7 +164,7 @@ class ParameterTest extends FTestPlatform: // Test string with quotes val quotedStringParam = Parameter.string("test'quotes") - assertEquals(quotedStringParam.toString, "'test'quotes'") + assertEquals(quotedStringParam.toString, "'test''quotes'") // Test zero values assertEquals(Parameter.byte(0).toString, "0") diff --git a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/net/protocol/AuthenticationTest.scala b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/net/protocol/AuthenticationTest.scala index 114359841..c256cf8d1 100644 --- a/module/ldbc-connector/shared/src/test/scala/ldbc/connector/net/protocol/AuthenticationTest.scala +++ b/module/ldbc-connector/shared/src/test/scala/ldbc/connector/net/protocol/AuthenticationTest.scala @@ -8,11 +8,6 @@ package ldbc.connector.net.protocol import cats.effect.{ IO, Ref } -import fs2.hashing.Hashing - -import ldbc.connector.authenticator.* -import ldbc.connector.exception.* -import ldbc.connector.util.Version import ldbc.connector.FTestPlatform class AuthenticationTest extends Authentication[IO], FTestPlatform: @@ -50,52 +45,6 @@ class AuthenticationTest extends Authentication[IO], FTestPlatform: _ <- selectedPluginName.set(None) } yield () - // Test for determinatePlugin - test("determinatePlugin should return the correct plugin for mysql_native_password") { - val version = Version(8, 0, 26) - val result = determinatePlugin("mysql_native_password", version) - assert(result.isRight, "Expected Right but got Left") - assert( - result.exists(_.isInstanceOf[MysqlNativePasswordPlugin[IO]]), - s"Expected 
MysqlNativePasswordPlugin but got ${ result.map(_.getClass.getSimpleName) }" - ) - } - - test("determinatePlugin should return the correct plugin for sha256_password") { - val version = Version(8, 0, 26) - val result = determinatePlugin("sha256_password", version) - assert(result.isRight, "Expected Right but got Left") - assert( - result.exists(_.isInstanceOf[Sha256PasswordPlugin[IO]]), - s"Expected Sha256PasswordPlugin but got ${ result.map(_.getClass.getSimpleName) }" - ) - } - - test("determinatePlugin should return the correct plugin for caching_sha2_password") { - val version = Version(8, 0, 26) - val result = determinatePlugin("caching_sha2_password", version) - assert(result.isRight, "Expected Right but got Left") - assert( - result.exists(_.isInstanceOf[CachingSha2PasswordPlugin[IO]]), - s"Expected CachingSha2PasswordPlugin but got ${ result.map(_.getClass.getSimpleName) }" - ) - } - - test("determinatePlugin should return Left for unknown plugin") { - val version = Version(8, 0, 26) - val result = determinatePlugin("unknown_plugin", version) - assert(result.isLeft, "Expected Left but got Right") - val exception = result.left.getOrElse(throw new RuntimeException("Expected Left but got Right")) - assert( - exception.isInstanceOf[SQLInvalidAuthorizationSpecException], - s"Expected SQLInvalidAuthorizationSpecException but got ${ exception.getClass.getSimpleName }" - ) - assert( - exception.getMessage.contains("Unknown authentication plugin: unknown_plugin"), - s"Error message did not match expected. 
Got: ${ exception.getMessage }" - ) - } - test("startAuthentication should set the correct state") { for { _ <- resetState diff --git a/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala b/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala index d8ff38a01..7633c85b8 100644 --- a/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala +++ b/module/ldbc-core/src/main/scala/ldbc/free/KleisliInterpreter.scala @@ -140,8 +140,12 @@ class KleisliInterpreter[F[_]: Sync](logHandler: LogHandler[F]) extends Interpre override def onCancel[A](fa: StatementIO[A], fin: StatementIO[Unit]): Kleisli[F, Statement[F], A] = outer.onCancel(this)(fa, fin) - override def addBatch(sql: String): Kleisli[F, Statement[F], Unit] = primitive(_.addBatch(sql)) - override def executeBatch(): Kleisli[F, Statement[F], Array[Int]] = primitive(_.executeBatch()) + override def executeQuery(sql: String): Kleisli[F, Statement[F], ResultSet[?]] = + primitive[Statement[F], ResultSet[F]](_.executeQuery(sql)).asInstanceOf[Kleisli[F, Statement[F], ResultSet[?]]] + override def executeUpdate(sql: String): Kleisli[F, Statement[F], Int] = primitive(_.executeUpdate(sql)) + override def addBatch(sql: String): Kleisli[F, Statement[F], Unit] = primitive(_.addBatch(sql)) + override def executeBatch(): Kleisli[F, Statement[F], Array[Int]] = primitive(_.executeBatch()) + override def close(): Kleisli[F, Statement[F], Unit] = primitive(_.close()) trait PreparedStatementInterpreter extends PreparedStatementOp.Visitor[[A] =>> Kleisli[F, PreparedStatement[F], A]]: diff --git a/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala b/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala index 026e1f820..8a94b340b 100644 --- a/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala +++ b/module/ldbc-core/src/main/scala/ldbc/free/StatementIO.scala @@ -42,10 +42,16 @@ object StatementOp: final case class OnCancel[A](fa: StatementIO[A], fin: StatementIO[Unit]) extends 
StatementOp[A]: override def visit[F[_]](v: StatementOp.Visitor[F]): F[A] = v.onCancel(fa, fin) + final case class ExecuteQuery(sql: String) extends StatementOp[ResultSet[?]]: + override def visit[F[_]](v: StatementOp.Visitor[F]): F[ResultSet[?]] = v.executeQuery(sql) + final case class ExecuteUpdate(sql: String) extends StatementOp[Int]: + override def visit[F[_]](v: StatementOp.Visitor[F]): F[Int] = v.executeUpdate(sql) final case class AddBatch[A](sql: String) extends StatementOp[Unit]: override def visit[F[_]](v: StatementOp.Visitor[F]): F[Unit] = v.addBatch(sql) final case class ExecuteBatch[A]() extends StatementOp[Array[Int]]: override def visit[F[_]](v: StatementOp.Visitor[F]): F[Array[Int]] = v.executeBatch() + final case class Close() extends StatementOp[Unit]: + override def visit[F[_]](v: StatementOp.Visitor[F]): F[Unit] = v.close() given Embeddable[StatementOp, Statement[?]] = new Embeddable[StatementOp, Statement[?]]: @@ -67,8 +73,11 @@ object StatementOp: def canceled: F[Unit] def onCancel[A](fa: StatementIO[A], fin: StatementIO[Unit]): F[A] - def addBatch(sql: String): F[Unit] - def executeBatch(): F[Array[Int]] + def executeQuery(sql: String): F[ResultSet[?]] + def executeUpdate(sql: String): F[Int] + def addBatch(sql: String): F[Unit] + def executeBatch(): F[Array[Int]] + def close(): F[Unit] type StatementIO[A] = Free[StatementOp, A] @@ -95,5 +104,9 @@ object StatementIO: def capturePoll[M[_]](mpoll: Poll[M]): Poll[StatementIO] = new Poll[StatementIO]: override def apply[A](fa: StatementIO[A]): StatementIO[A] = Free.liftF[StatementOp, A](StatementOp.Poll1(mpoll, fa)) - def addBatch(sql: String): StatementIO[Unit] = Free.liftF[StatementOp, Unit](StatementOp.AddBatch(sql)) - def executeBatch(): StatementIO[Array[Int]] = Free.liftF[StatementOp, Array[Int]](StatementOp.ExecuteBatch()) + def executeQuery(sql: String): StatementIO[ResultSet[?]] = + Free.liftF[StatementOp, ResultSet[?]](StatementOp.ExecuteQuery(sql)) + def executeUpdate(sql: String): 
StatementIO[Int] = Free.liftF[StatementOp, Int](StatementOp.ExecuteUpdate(sql)) + def addBatch(sql: String): StatementIO[Unit] = Free.liftF[StatementOp, Unit](StatementOp.AddBatch(sql)) + def executeBatch(): StatementIO[Array[Int]] = Free.liftF[StatementOp, Array[Int]](StatementOp.ExecuteBatch()) + def close(): StatementIO[Unit] = Free.liftF[StatementOp, Unit](StatementOp.Close()) diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala index e43ffe5a8..74c92baae 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/DBIO.scala @@ -27,8 +27,47 @@ import ldbc.* import ldbc.free.* import ldbc.logging.LogEvent +/** + * DBIO (Database I/O) provides a set of combinators for building type-safe database operations + * using the Free monad pattern. It abstracts over JDBC operations and provides a functional + * interface for executing SQL queries with proper resource management and error handling. + * + * The DBIO type is a type alias for `Free[ConnectionOp, A]`, which represents a computation + * that produces a value of type `A` when run against a database connection. + * + * @example + * {{{ + * import ldbc.dsl.* + * + * val query: DBIO[List[User]] = + * sql"SELECT * FROM users WHERE age > ${ 18 }".query[User].to[List] + * + * val result: F[List[User]] = query.run(connector) + * }}} + */ object DBIO: + /** + * Binds dynamic parameters to a PreparedStatement. + * + * This method is the core mechanism for safely binding values to SQL prepared statements. + * It processes a list of dynamic parameters, extracting their encoded values and binding + * them to the appropriate positions in the PreparedStatement. + * + * The binding process: + * 1. Validates all parameters, collecting any encoding failures + * 2. Maps each encoded value to the correct PreparedStatement setter method + * 3. 
Handles null values by calling setNull with the appropriate SQL type + * + * @param params List of dynamic parameters to bind. Each parameter contains either + * a successfully encoded value or encoding error information + * @return A PreparedStatementIO action that performs the binding operations + * @throws IllegalArgumentException if any parameter contains encoding failures + * + * @note JDBC uses 1-based indexing for parameters, so the index is incremented by 1 + * @note The method uses pattern matching to determine the appropriate setter method + * based on the runtime type of each encoded value + */ private def paramBind(params: List[Parameter.Dynamic]): PreparedStatementIO[Unit] = val encoded = params.foldLeft(PreparedStatementIO.pure(List.empty[Encoder.Supported])) { case (acc, param) => @@ -63,6 +102,21 @@ object DBIO: } yield () + /** + * Reads a single row from a ResultSet and decodes it using the provided decoder. + * + * This method expects exactly one row in the ResultSet. It will: + * - Advance to the first row using ResultSet.next() + * - Decode the row using the provided decoder + * - Return an error if no rows are found or if decoding fails + * + * @param statement The SQL statement being executed (used for error messages) + * @param decoder The decoder to convert the row data into type T + * @tparam T The type to decode the row into + * @return A ResultSetIO action that produces a value of type T + * @throws UnexpectedContinuation if the ResultSet is empty + * @throws DecodeFailureException if the decoder fails to decode the row + */ private def unique[T]( statement: String, decoder: Decoder[T] @@ -77,6 +131,21 @@ object DBIO: case false => ResultSetIO.raiseError(new UnexpectedContinuation("Expected ResultSet to have at least one row.")) } + /** + * Reads all rows from a ResultSet and collects them into a collection of type G[T]. 
+ * + * This method iterates through all rows in the ResultSet, decoding each row + * and accumulating the results into a collection using the provided factory. + * It continues until ResultSet.next() returns false. + * + * @param statement The SQL statement being executed (used for error messages) + * @param decoder The decoder to convert each row into type T + * @param factoryCompat A factory for building the target collection type G[T] + * @tparam G The collection type constructor (e.g., List, Vector, Set) + * @tparam T The element type to decode each row into + * @return A ResultSetIO action that produces a collection of decoded values + * @throws DecodeFailureException if the decoder fails on any row + */ private def whileM[G[_], T]( statement: String, decoder: Decoder[T], @@ -102,6 +171,19 @@ object DBIO: loop(builder).map(_.result()) + /** + * Reads all rows from a ResultSet and returns them as a NonEmptyList. + * + * This method ensures that at least one row exists in the ResultSet. + * If the ResultSet is empty, it raises an UnexpectedEnd error. + * + * @param statement The SQL statement being executed (used for error messages) + * @param decoder The decoder to convert each row into type A + * @tparam A The type to decode each row into + * @return A ResultSetIO action that produces a NonEmptyList of decoded values + * @throws UnexpectedEnd if the ResultSet contains no rows + * @throws DecodeFailureException if the decoder fails on any row + */ private def nel[A]( statement: String, decoder: Decoder[A] @@ -111,6 +193,35 @@ object DBIO: else ResultSetIO.pure(NonEmptyList.fromListUnsafe(results)) } + /** + * Executes a SQL query that returns exactly one row and decodes it to type A. + * + * This method: + * 1. Prepares a statement with the given SQL + * 2. Binds the provided parameters + * 3. Executes the query + * 4. Reads exactly one row from the result + * 5. Closes the prepared statement + * 6. 
Logs the operation (success or failure) + * + * @param statement The SQL query to execute + * @param params Dynamic parameters to bind to the query + * @param decoder Decoder to convert the result row to type A + * @tparam A The result type + * @return A DBIO action that produces a single value of type A + * @throws UnexpectedContinuation if the ResultSet is empty + * @throws DecodeFailureException if decoding fails + * + * @example + * {{{ + * val userId = 42 + * val query = DBIO.queryA( + * "SELECT name FROM users WHERE id = ?", + * List(Parameter.Dynamic.Success(userId)), + * Decoder[String] + * ) + * }}} + */ def queryA[A]( statement: String, params: List[Parameter.Dynamic], @@ -129,6 +240,31 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value))) + /** + * Executes a SQL query and collects all results into a collection of type G[A]. + * + * This method allows you to specify the target collection type through the + * factory parameter. Common usage includes collecting to List, Vector, Set, etc. 
+ * + * @param statement The SQL query to execute + * @param params Dynamic parameters to bind to the query + * @param decoder Decoder to convert each result row to type A + * @param factory Factory for building the target collection type + * @tparam G The collection type constructor (e.g., List, Vector, Set) + * @tparam A The element type + * @return A DBIO action that produces a collection of decoded values + * @throws DecodeFailureException if decoding fails for any row + * + * @example + * {{{ + * val query = DBIO.queryTo[List, String]( + * "SELECT name FROM users WHERE age > ?", + * List(Parameter.Dynamic.Success(18)), + * Decoder[String], + * summon[FactoryCompat[String, List[String]]] + * ) + * }}} + */ def queryTo[G[_], A]( statement: String, params: List[Parameter.Dynamic], @@ -148,6 +284,33 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value))) + /** + * Executes a SQL query that returns at most one row as an Option[A]. + * + * This method is useful when a query might return zero or one row. 
+ * It returns: + * - Some(a) if exactly one row is found and successfully decoded + * - None if no rows are found + * - An error if more than one row is found + * + * @param statement The SQL query to execute + * @param params Dynamic parameters to bind to the query + * @param decoder Decoder to convert the result row to type A + * @tparam A The result type + * @return A DBIO action that produces an Option[A] + * @throws UnexpectedContinuation if more than one row is found + * @throws DecodeFailureException if decoding fails + * + * @example + * {{{ + * val query = DBIO.queryOption( + * "SELECT name FROM users WHERE email = ?", + * List(Parameter.Dynamic.Success("user@example.com")), + * Decoder[String] + * ) + * // Returns Some(name) if user exists, None otherwise + * }}} + */ def queryOption[A]( statement: String, params: List[Parameter.Dynamic], @@ -188,6 +351,31 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value))) + /** + * Executes a SQL query that returns at least one row as a NonEmptyList[A]. + * + * This method ensures that the query returns at least one result. + * It's useful when you need compile-time guarantees that your result + * set is not empty. 
+ * + * @param statement The SQL query to execute + * @param params Dynamic parameters to bind to the query + * @param decoder Decoder to convert each result row to type A + * @tparam A The result type + * @return A DBIO action that produces a NonEmptyList[A] + * @throws UnexpectedEnd if no rows are found + * @throws DecodeFailureException if decoding fails for any row + * + * @example + * {{{ + * val query = DBIO.queryNel( + * "SELECT id FROM users WHERE role = ?", + * List(Parameter.Dynamic.Success("admin")), + * Decoder[Int] + * ) + * // Fails at runtime if no admin users exist + * }}} + */ def queryNel[A]( statement: String, params: List[Parameter.Dynamic], @@ -206,6 +394,28 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value))) + /** + * Executes a SQL update statement (INSERT, UPDATE, DELETE) with parameters. + * + * This method prepares the statement, binds parameters, executes the update, + * and returns the number of affected rows. + * + * @param statement The SQL update statement to execute + * @param params Dynamic parameters to bind to the statement + * @return A DBIO action that produces the number of affected rows + * + * @example + * {{{ + * val update = DBIO.update( + * "UPDATE users SET name = ? WHERE id = ?", + * List( + * Parameter.Dynamic.Success("New Name"), + * Parameter.Dynamic.Success(42) + * ) + * ) + * // Returns the number of updated rows + * }}} + */ def update( statement: String, params: List[Parameter.Dynamic] @@ -222,6 +432,98 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value))) + /** + * Executes a raw SQL update statement without parameters. + * + * This method uses a Statement instead of PreparedStatement and is suitable + * for DDL operations or updates that don't require parameter binding. 
+ * + * @param statement The SQL statement to execute + * @return A DBIO action that produces the number of affected rows + * + * @example + * {{{ + * val create = DBIO.updateRaw( + * "CREATE TABLE IF NOT EXISTS users (id INT PRIMARY KEY, name VARCHAR(255))" + * ) + * }}} + * + * @note Use parameterized queries (update method) when dealing with user input + * to prevent SQL injection attacks + */ + def updateRaw(statement: String): DBIO[Int] = + (for + stmt <- ConnectionIO.createStatement() + result <- ConnectionIO.embed(stmt, StatementIO.executeUpdate(statement)) + _ <- ConnectionIO.embed(stmt, StatementIO.close()) + yield result).onError { ex => + ConnectionIO.performLogging(LogEvent.ProcessingFailure(statement, List.empty, ex)) + } <* + ConnectionIO.performLogging(LogEvent.Success(statement, List.empty)) + + /** + * Executes multiple SQL statements separated by semicolons as a batch. + * + * This method splits the input by semicolons and executes each statement + * as part of a batch operation. It's useful for executing multiple DDL + * or DML statements in a single round trip to the database. 
+ * + * @param statement Multiple SQL statements separated by semicolons + * @return A DBIO action that produces an array of update counts + * + * @example + * {{{ + * val batch = DBIO.updateRaws(""" + * CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(255)); + * CREATE INDEX idx_name ON users(name); + * INSERT INTO users VALUES (1, 'Alice'), (2, 'Bob'); + * """) + * // Returns an array with the result of each statement + * }}} + * + * @note Each statement is executed independently; failure of one doesn't + * necessarily rollback others unless in a transaction + */ + def updateRaws(statement: String): DBIO[Array[Int]] = + (for + stmt <- ConnectionIO.createStatement() + statements = statement.trim.split(";").toList + _ <- ConnectionIO.embed(stmt, statements.map(statement => StatementIO.addBatch(statement)).sequence) + result <- ConnectionIO.embed(stmt, StatementIO.executeBatch()) + _ <- ConnectionIO.embed(stmt, StatementIO.close()) + yield result).onError { ex => + ConnectionIO.performLogging(LogEvent.ProcessingFailure(statement, List.empty, ex)) + } <* + ConnectionIO.performLogging(LogEvent.Success(statement, List.empty)) + + /** + * Executes an INSERT statement and returns generated keys. + * + * This method is typically used for INSERT statements where you want to + * retrieve auto-generated keys (like auto-increment IDs). The statement + * is prepared with RETURN_GENERATED_KEYS flag. 
+ * + * @param statement The INSERT statement to execute + * @param params Dynamic parameters to bind to the statement + * @param decoder Decoder to convert the generated key to type A + * @tparam A The type of the generated key + * @return A DBIO action that produces the generated key + * @throws UnexpectedContinuation if no keys are generated + * @throws DecodeFailureException if decoding the key fails + * + * @example + * {{{ + * val insert = DBIO.returning[Long]( + * "INSERT INTO users (name, email) VALUES (?, ?)", + * List( + * Parameter.Dynamic.Success("Alice"), + * Parameter.Dynamic.Success("alice@example.com") + * ), + * Decoder[Long] // For auto-increment ID + * ) + * // Returns the generated user ID + * }}} + */ def returning[A]( statement: String, params: List[Parameter.Dynamic], @@ -241,6 +543,27 @@ object DBIO: } <* ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value))) + /** + * Executes a batch of SQL statements. + * + * This method executes multiple statements as a batch operation, + * which is more efficient than executing them individually. + * Unlike updateRaws, this takes a list of complete statements. 
+ * + * @param statements List of SQL statements to execute + * @return A DBIO action that produces an array of update counts + * + * @example + * {{{ + * val statements = List( + * "INSERT INTO users (name) VALUES ('Alice')", + * "INSERT INTO users (name) VALUES ('Bob')", + * "UPDATE users SET active = true WHERE name = 'Alice'" + * ) + * val batch = DBIO.sequence(statements) + * // Returns update counts for each statement + * }}} + */ def sequence( statements: List[String] ): DBIO[Array[Int]] = @@ -248,11 +571,48 @@ object DBIO: statement <- ConnectionIO.createStatement() _ <- ConnectionIO.embed(statement, statements.map(statement => StatementIO.addBatch(statement)).sequence) result <- ConnectionIO.embed(statement, StatementIO.executeBatch()) + _ <- ConnectionIO.embed(statement, StatementIO.close()) yield result).onError { ex => ConnectionIO.performLogging(LogEvent.ProcessingFailure(statements.mkString("\n"), List.empty, ex)) } <* ConnectionIO.performLogging(LogEvent.Success(statements.mkString("\n"), List.empty)) + /** + * Creates a streaming query that processes results incrementally. + * + * This method is ideal for processing large result sets without loading + * all data into memory at once. It uses JDBC fetch size to control + * how many rows are fetched from the database in each round trip. + * + * The stream is resource-safe and will automatically close the + * PreparedStatement when the stream terminates (normally or via error). 
+ * + * @param statement The SQL query to execute + * @param params Dynamic parameters to bind to the query + * @param decoder Decoder to convert each result row to type A + * @param fetchSize Number of rows to fetch in each batch from the database + * @tparam A The result type + * @return An fs2.Stream that emits decoded values + * + * @example + * {{{ + * val stream = DBIO.stream( + * "SELECT * FROM large_table WHERE category = ?", + * List(Parameter.Dynamic.Success("books")), + * Decoder[Product], + * fetchSize = 1000 + * ) + * + * // Process results as they arrive + * stream + * .evalMap(product => processProduct(product)) + * .compile + * .drain + * .run(connector) + * }}} + * + * @note The fetch size is a hint to the JDBC driver and may be ignored + */ def stream[A]( statement: String, params: List[Parameter.Dynamic], @@ -297,12 +657,56 @@ object DBIO: } <* fs2.Stream.eval(ConnectionIO.performLogging(LogEvent.Success(statement, params.map(_.value)))) + /** + * Combines multiple DBIO actions into a single action that produces a list of results. + * + * This is equivalent to the traverse operation, executing each DBIO in sequence + * and collecting all results. + * + * @param dbios Variable number of DBIO actions to sequence + * @tparam A The result type of each DBIO + * @return A DBIO action that produces a List[A] containing all results + * + * @example + * {{{ + * val combined = DBIO.sequence( + * sql"SELECT COUNT(*) FROM users".query[Int].to[Option], + * sql"SELECT COUNT(*) FROM products".query[Int].to[Option] + * ) + * // Returns List(userCount, productCount) + * }}} + */ def sequence[A](dbios: DBIO[A]*): DBIO[List[A]] = dbios.toList.sequence - def pure[A](value: A): DBIO[A] = Free.pure(value) + /** + * Lifts a pure value into the DBIO context. 
+ * + * @param value The value to lift + * @tparam A The type of the value + * @return A DBIO action that produces the given value + */ + def pure[A](value: A): DBIO[A] = Free.pure(value) + + /** + * Creates a failed DBIO action with the given error. + * + * @param e The error to raise + * @tparam A The expected result type (never produced) + * @return A DBIO action that fails with the given error + */ def raiseError[A](e: Throwable): DBIO[A] = ConnectionIO.raiseError(e) + /** + * Provides a Sync type class instance for DBIO. + * + * This enables DBIO to be used with any Cats Effect operations that + * require a Sync constraint, including error handling, resource management, + * and cancellation support. + * + * The implementation delegates to the underlying ConnectionIO Free monad + * operations, ensuring proper effect handling throughout the DBIO lifecycle. + */ implicit val syncDBIO: Sync[DBIO] = new Sync[DBIO]: val monad = Free.catsFreeMonadForFree[ConnectionOp] @@ -322,16 +726,67 @@ object DBIO: override def canceled: DBIO[Unit] = ConnectionIO.canceled override def onCancel[A](fa: DBIO[A], fin: DBIO[Unit]): DBIO[A] = ConnectionIO.onCancel(fa, fin) + /** + * Extension methods for DBIO values that provide various execution strategies. + * + * This class is made available through an implicit conversion and provides + * methods to run DBIO actions with different transaction semantics. + * + * @param dbio The DBIO action to execute + * @tparam A The result type of the DBIO action + */ private[ldbc] class Ops[A](dbio: DBIO[A]): + /** + * Runs the DBIO action using the provided connector. + * + * This is the basic execution method that runs the action + * with whatever transaction settings are currently active. 
+ * + * @param connector The database connector to use + * @tparam F The effect type (e.g., IO, Task) + * @return The result wrapped in the effect type F + */ def run[F[_]](connector: Connector[F]): F[A] = connector.run(dbio) + /** + * Runs the DBIO action in read-only mode. + * + * This ensures that the connection is set to read-only before executing + * the action and restored afterwards. Useful for ensuring queries don't + * accidentally modify data. + * + * @param connector The database connector to use + * @tparam F The effect type + * @return The result wrapped in the effect type F + */ def readOnly[F[_]](connector: Connector[F]): F[A] = connector.run(ConnectionIO.setReadOnly(true) *> dbio <* ConnectionIO.setReadOnly(false)) + /** + * Runs the DBIO action with auto-commit enabled. + * + * Each statement is automatically committed after execution. + * This is typically used for single statement operations that + * should be immediately persisted. + * + * @param connector The database connector to use + * @tparam F The effect type + * @return The result wrapped in the effect type F + */ def commit[F[_]](connector: Connector[F]): F[A] = connector.run(ConnectionIO.setReadOnly(false) *> ConnectionIO.setAutoCommit(true) *> dbio) + /** + * Runs the DBIO action and rolls back all changes at the end. + * + * This is useful for testing or dry-run scenarios where you want + * to execute operations but not persist any changes. + * + * @param connector The database connector to use + * @tparam F The effect type + * @return The result wrapped in the effect type F + */ def rollback[F[_]](connector: Connector[F]): F[A] = connector.run( ConnectionIO.setReadOnly(false) *> @@ -341,6 +796,27 @@ object DBIO: ConnectionIO.setAutoCommit(true) ) + /** + * Runs the DBIO action within a transaction. + * + * The action is executed with auto-commit disabled. If the action + * completes successfully, changes are committed. If an error occurs, + * all changes are rolled back. 
+ * + * @param connector The database connector to use + * @tparam F The effect type + * @return The result wrapped in the effect type F + * + * @example + * {{{ + * val transfer = for { + * _ <- sql"UPDATE accounts SET balance = balance - 100 WHERE id = 1".update + * _ <- sql"UPDATE accounts SET balance = balance + 100 WHERE id = 2".update + * } yield () + * + * transfer.transaction(connector) // Both updates succeed or both fail + * }}} + */ def transaction[F[_]](connector: Connector[F]): F[A] = connector.run( (ConnectionIO.setReadOnly(false) *> ConnectionIO.setAutoCommit(false) *> dbio) diff --git a/module/ldbc-hikari/src/main/scala/ldbc/hikari/Configuration.scala b/module/ldbc-hikari/src/main/scala/ldbc/hikari/Configuration.scala deleted file mode 100644 index 9c1c852e8..000000000 --- a/module/ldbc-hikari/src/main/scala/ldbc/hikari/Configuration.scala +++ /dev/null @@ -1,105 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). 
- * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.hikari - -import java.time.Duration as JavaDuration - -import scala.concurrent.duration.{ Duration, FiniteDuration, * } -import scala.jdk.CollectionConverters.* - -import com.typesafe.config.* - -case class Configuration(config: Config): - def get[A](path: String)(using loader: ConfigLoader[A]): A = - loader.load(config, path) - -object Configuration: - - def load( - classLoader: ClassLoader, - directSettings: Map[String, String] - ): Configuration = - try - val directConfig: Config = ConfigFactory.parseMap(directSettings.asJava) - val config: Config = ConfigFactory.load(classLoader, directConfig) - Configuration(config) - catch case e: ConfigException => throw new Exception(e.getMessage) - - def load(): Configuration = Configuration(ConfigFactory.load()) - -trait ConfigLoader[A]: - self => - - def load(config: Config, path: String): A - - def map[B](f: A => B): ConfigLoader[B] = (config: Config, path: String) => f(self.load(config, path)) - -object ConfigLoader: - def apply[A](f: Config => String => A): ConfigLoader[A] = - (config: Config, path: String) => f(config)(path) - - given ConfigLoader[String] = ConfigLoader(_.getString) - given ConfigLoader[Int] = ConfigLoader(_.getInt) - given ConfigLoader[Long] = ConfigLoader(_.getLong) - given ConfigLoader[Number] = ConfigLoader(_.getNumber) - given ConfigLoader[Double] = ConfigLoader(_.getDouble) - given ConfigLoader[Boolean] = ConfigLoader(_.getBoolean) - given ConfigLoader[ConfigMemorySize] = ConfigLoader(_.getMemorySize) - given ConfigLoader[FiniteDuration] = ConfigLoader(_.getDuration).map(_.toNanos.nanos) - given ConfigLoader[JavaDuration] = ConfigLoader(_.getDuration) - given ConfigLoader[Duration] = ConfigLoader(config => - path => - if config.getIsNull(path) then Duration.Inf - else config.getDuration(path).toNanos.nanos - ) - - given seqBoolean: ConfigLoader[Seq[Boolean]] = - 
ConfigLoader(_.getBooleanList).map(_.asScala.map(_.booleanValue).toSeq) - given seqInt: ConfigLoader[Seq[Int]] = - ConfigLoader(_.getIntList).map(_.asScala.map(_.toInt).toSeq) - given seqLong: ConfigLoader[Seq[Long]] = - ConfigLoader(_.getDoubleList).map(_.asScala.map(_.longValue).toSeq) - given seqNumber: ConfigLoader[Seq[Number]] = - ConfigLoader(_.getNumberList).map(_.asScala.toSeq) - given seqDouble: ConfigLoader[Seq[Double]] = - ConfigLoader(_.getDoubleList).map(_.asScala.map(_.doubleValue).toSeq) - given seqString: ConfigLoader[Seq[String]] = - ConfigLoader(_.getStringList).map(_.asScala.toSeq) - given seqBytes: ConfigLoader[Seq[ConfigMemorySize]] = - ConfigLoader(_.getMemorySizeList).map(_.asScala.toSeq) - given seqFinite: ConfigLoader[Seq[FiniteDuration]] = - ConfigLoader(_.getDurationList).map(_.asScala.map(_.toNanos.nanos).toSeq) - given seqJavaDuration: ConfigLoader[Seq[JavaDuration]] = - ConfigLoader(_.getDurationList).map(_.asScala.toSeq) - given seqScalaDuration: ConfigLoader[Seq[Duration]] = - ConfigLoader(_.getDurationList).map(_.asScala.map(_.toNanos.nanos).toSeq) - given seqConfig: ConfigLoader[Seq[Config]] = - ConfigLoader(_.getConfigList).map(_.asScala.toSeq) - given seqConfiguration: ConfigLoader[Seq[Configuration]] = - summon[ConfigLoader[Seq[Config]]].map(_.map(Configuration(_))) - - given ConfigLoader[Config] = ConfigLoader(_.getConfig) - given ConfigLoader[ConfigObject] = ConfigLoader(_.getObject) - given ConfigLoader[ConfigList] = ConfigLoader(_.getList) - given ConfigLoader[Configuration] = summon[ConfigLoader[Config]].map(Configuration(_)) - - given [A](using loader: ConfigLoader[A]): ConfigLoader[Option[A]] with - override def load(config: Config, path: String): Option[A] = - if config.hasPath(path) && !config.getIsNull(path) then Some(loader.load(config, path)) - else None - - given [A](using loader: ConfigLoader[A]): ConfigLoader[Map[String, A]] with - override def load(config: Config, path: String): Map[String, A] = - val obj = 
config.getObject(path) - val conf = obj.toConfig - obj - .keySet() - .asScala - .map { key => - key -> loader.load(conf, key) - } - .toMap diff --git a/module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariConfigBuilder.scala b/module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariConfigBuilder.scala deleted file mode 100644 index 6769ee53c..000000000 --- a/module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariConfigBuilder.scala +++ /dev/null @@ -1,293 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.hikari - -import java.util.concurrent.{ ScheduledExecutorService, ThreadFactory, TimeUnit } -import java.util.Properties - -import javax.sql.DataSource as JDataSource - -import scala.concurrent.duration.Duration - -import com.zaxxer.hikari.metrics.MetricsTrackerFactory -import com.zaxxer.hikari.HikariConfig - -/** - * Build the Configuration of HikariCP. - */ -trait HikariConfigBuilder: - - protected val config: Configuration = Configuration.load() - - protected val path: String = "ldbc.hikari" - - /** List of keys to retrieve from conf file. 
*/ - final private val CATALOG = "catalog" - final private val CONNECTION_TIMEOUT = "connection_timeout" - final private val IDLE_TIMEOUT = "idle_timeout" - final private val LEAK_DETECTION_THRESHOLD = "leak_detection_threshold" - final private val MAXIMUM_POOL_SIZE = "maximum_pool_size" - final private val MAX_LIFETIME = "max_lifetime" - final private val MINIMUM_IDLE = "minimum_idle" - final private val POOL_NAME = "pool_name" - final private val VALIDATION_TIMEOUT = "validation_timeout" - final private val ALLOW_POOL_SUSPENSION = "allow_pool_suspension" - final private val AUTO_COMMIT = "auto_commit" - final private val CONNECTION_INIT_SQL = "connection_init_sql" - final private val CONNECTION_TEST_QUERY = "connection_test_query" - final private val DATA_SOURCE_CLASSNAME = "data_source_classname" - final private val DATASOURCE_JNDI = "datasource_jndi" - final private val INITIALIZATION_FAIL_TIMEOUT = "initialization_fail_timeout" - final private val ISOLATE_INTERNAL_QUERIES = "isolate_internal_queries" - final private val JDBC_URL = "jdbc_url" - final private val READONLY = "readonly" - final private val REGISTER_MBEANS = "register_mbeans" - final private val SCHEMA = "schema" - final private val USERNAME = "username" - final private val PASSWORD = "password" - final private val DRIVER_CLASS_NAME = "driver_class_name" - final private val TRANSACTION_ISOLATION = "transaction_isolation" - - /** Number of application cores */ - private val maxCore: Int = Runtime.getRuntime.availableProcessors() - - /** - * Method to retrieve values matching any key from the conf file from the path configuration, with any type. - * - * @param func - * Process to get values from Configuration wrapped in Option - * @tparam T - * Type of value retrieved from conf file - */ - final private def readConfig[T](func: Configuration => Option[T]): Option[T] = - config.get[Option[Configuration]](path).flatMap(func(_)) - - /** Method to retrieve catalog information from the conf file. 
*/ - private def getCatalog: Option[String] = - readConfig(_.get[Option[String]](CATALOG)) - - /** Method to retrieve connection timeout information from the conf file. */ - private def getConnectionTimeout: Option[Duration] = - readConfig(_.get[Option[Duration]](CONNECTION_TIMEOUT)) - - /** Method to retrieve idle timeout information from the conf file. */ - private def getIdleTimeout: Option[Duration] = - readConfig(_.get[Option[Duration]](IDLE_TIMEOUT)) - - /** Method to retrieve leak detection threshold information from the conf file. */ - private def getLeakDetectionThreshold: Option[Duration] = - readConfig(_.get[Option[Duration]](LEAK_DETECTION_THRESHOLD)) - - /** Method to retrieve maximum pool size information from the conf file. */ - private def getMaximumPoolSize: Option[Int] = - readConfig(_.get[Option[Int]](MAXIMUM_POOL_SIZE)) - - /** Method to retrieve max life time information from the conf file. */ - private def getMaxLifetime: Option[Duration] = - readConfig(_.get[Option[Duration]](MAX_LIFETIME)) - - /** Method to retrieve minimum idle information from the conf file. */ - private def getMinimumIdle: Option[Int] = - readConfig(_.get[Option[Int]](MINIMUM_IDLE)) - - /** Method to retrieve pool name information from the conf file. */ - private def getPoolName: Option[String] = - readConfig(_.get[Option[String]](POOL_NAME)) - - /** Method to retrieve validation timeout information from the conf file. */ - private def getValidationTimeout: Option[Duration] = - readConfig(_.get[Option[Duration]](VALIDATION_TIMEOUT)) - - /** Method to retrieve allow pool suspension information from the conf file. */ - private def getAllowPoolSuspension: Option[Boolean] = - readConfig(_.get[Option[Boolean]](ALLOW_POOL_SUSPENSION)) - - /** Method to retrieve auto commit information from the conf file. 
*/ - private def getAutoCommit: Option[Boolean] = - readConfig(_.get[Option[Boolean]](AUTO_COMMIT)) - - /** Method to retrieve connection init sql information from the conf file. */ - private def getConnectionInitSql: Option[String] = - readConfig(_.get[Option[String]](CONNECTION_INIT_SQL)) - - /** Method to retrieve connection test query information from the conf file. */ - private def getConnectionTestQuery: Option[String] = - readConfig(_.get[Option[String]](CONNECTION_TEST_QUERY)) - - /** Method to retrieve data source class name information from the conf file. */ - private def getDataSourceClassname: Option[String] = - readConfig(_.get[Option[String]](DATA_SOURCE_CLASSNAME)) - - /** Method to retrieve data source jndi information from the conf file. */ - private def getDatasourceJndi: Option[String] = - readConfig(_.get[Option[String]](DATASOURCE_JNDI)) - - /** Method to retrieve initialization fail time out information from the conf file. */ - private def getInitializationFailTimeout: Option[Duration] = - readConfig(_.get[Option[Duration]](INITIALIZATION_FAIL_TIMEOUT)) - - /** Method to retrieve isolate internal queries information from the conf file. */ - private def getIsolateInternalQueries: Option[Boolean] = - readConfig(_.get[Option[Boolean]](ISOLATE_INTERNAL_QUERIES)) - - /** Method to retrieve jdbc url information from the conf file. */ - private def getJdbcUrl: Option[String] = - readConfig(_.get[Option[String]](JDBC_URL)) - - /** Method to retrieve readonly information from the conf file. */ - private def getReadonly: Option[Boolean] = - readConfig(_.get[Option[Boolean]](READONLY)) - - /** Method to retrieve register mbeans information from the conf file. */ - private def getRegisterMbeans: Option[Boolean] = - readConfig(_.get[Option[Boolean]](REGISTER_MBEANS)) - - /** Method to retrieve schema information from the conf file. 
*/ - private def getSchema: Option[String] = - readConfig(_.get[Option[String]](SCHEMA)) - - /** Method to retrieve user name information from the conf file. */ - protected def getUserName: Option[String] = - readConfig(_.get[Option[String]](USERNAME)) - - /** Method to retrieve password information from the conf file. */ - protected def getPassword: Option[String] = - readConfig(_.get[Option[String]](PASSWORD)) - - /** Method to retrieve driver class name information from the conf file. */ - protected def getDriverClassName: Option[String] = - readConfig(_.get[Option[String]](DRIVER_CLASS_NAME)) - - /** Method to retrieve transaction isolation information from the conf file. */ - protected def getTransactionIsolation: Option[String] = - readConfig(_.get[Option[String]](TRANSACTION_ISOLATION)).map { v => - if v == "TRANSACTION_NONE" || v == "TRANSACTION_READ_UNCOMMITTED" || v == "TRANSACTION_READ_COMMITTED" || v == "TRANSACTION_REPEATABLE_READ" || v == "TRANSACTION_SERIALIZABLE" - then v - else - throw new IllegalArgumentException( - "TransactionIsolation must be TRANSACTION_NONE,TRANSACTION_READ_UNCOMMITTED,TRANSACTION_READ_COMMITTED,TRANSACTION_REPEATABLE_READ,TRANSACTION_SERIALIZABLE." - ) - } - - /** List of variables predefined as default settings. 
*/ - val connectionTimeout: Long = getConnectionTimeout.getOrElse(Duration(30, TimeUnit.SECONDS)).toMillis - val idleTimeout: Long = getIdleTimeout.getOrElse(Duration(10, TimeUnit.MINUTES)).toMillis - val leakDetectionThreshold: Long = getLeakDetectionThreshold.getOrElse(Duration.Zero).toMillis - val maximumPoolSize: Int = getMaximumPoolSize.getOrElse(maxCore * 2) - val maxLifetime: Long = getMaxLifetime.getOrElse(Duration(30, TimeUnit.MINUTES)).toMillis - val minimumIdle: Int = getMinimumIdle.getOrElse(10) - val validationTimeout: Long = getValidationTimeout.getOrElse(Duration(5, TimeUnit.SECONDS)).toMillis - val allowPoolSuspension: Boolean = getAllowPoolSuspension.getOrElse(false) - val autoCommit: Boolean = getAutoCommit.getOrElse(true) - val initializationFailTimeout: Long = - getInitializationFailTimeout.getOrElse(Duration(1, TimeUnit.MILLISECONDS)).toMillis - val isolateInternalQueries: Boolean = getIsolateInternalQueries.getOrElse(false) - val readonly: Boolean = getReadonly.getOrElse(false) - val registerMbeans: Boolean = getRegisterMbeans.getOrElse(false) - - /** - * Method to generate HikariConfig based on DatabaseConfig and other settings. - * - * @param jDataSource - * Factories for connection to physical data sources - * @param dataSourceProperties - * Properties (name/value pairs) used to configure DataSource/java.sql.Driver - * @param healthCheckProperties - * Properties (name/value pairs) used to configure HealthCheck/java.sql.Driver - * @param healthCheckRegistry - * Set the HealthCheckRegistry that will be used for registration of health checks by HikariCP. - * @param metricRegistry - * Set a MetricRegistry instance to use for registration of metrics used by HikariCP. - * @param metricsTrackerFactory - * Set a MetricsTrackerFactory instance to use for registration of metrics used by HikariCP. - * @param scheduledExecutor - * Set the ScheduledExecutorService used for housekeeping. 
- * @param threadFactory - * Set the thread factory to be used to create threads. - */ - def build( - jDataSource: Option[JDataSource] = None, - dataSourceProperties: Option[Properties] = None, - healthCheckProperties: Option[Properties] = None, - healthCheckRegistry: Option[Object] = None, - metricRegistry: Option[Object] = None, - metricsTrackerFactory: Option[MetricsTrackerFactory] = None, - scheduledExecutor: Option[ScheduledExecutorService] = None, - threadFactory: Option[ThreadFactory] = None - ): HikariConfig = - - val hikariConfig = new HikariConfig() - - getCatalog foreach hikariConfig.setCatalog - hikariConfig.setConnectionTimeout(connectionTimeout) - hikariConfig.setIdleTimeout(idleTimeout) - hikariConfig.setLeakDetectionThreshold(leakDetectionThreshold) - hikariConfig.setMaximumPoolSize(maximumPoolSize) - hikariConfig.setMaxLifetime(maxLifetime) - hikariConfig.setMinimumIdle(minimumIdle) - hikariConfig.setValidationTimeout(validationTimeout) - hikariConfig.setAllowPoolSuspension(allowPoolSuspension) - hikariConfig.setAutoCommit(autoCommit) - hikariConfig.setInitializationFailTimeout(initializationFailTimeout) - hikariConfig.setIsolateInternalQueries(isolateInternalQueries) - hikariConfig.setReadOnly(readonly) - hikariConfig.setRegisterMbeans(registerMbeans) - - getPassword foreach hikariConfig.setPassword - getPoolName foreach hikariConfig.setPoolName - getUserName foreach hikariConfig.setUsername - getConnectionInitSql foreach hikariConfig.setConnectionInitSql - getConnectionTestQuery foreach hikariConfig.setConnectionTestQuery - getDataSourceClassname foreach hikariConfig.setDataSourceClassName - getDatasourceJndi foreach hikariConfig.setDataSourceJNDI - getDriverClassName foreach hikariConfig.setDriverClassName - getJdbcUrl foreach hikariConfig.setJdbcUrl - getSchema foreach hikariConfig.setSchema - getTransactionIsolation foreach hikariConfig.setTransactionIsolation - - jDataSource foreach hikariConfig.setDataSource - dataSourceProperties foreach 
hikariConfig.setDataSourceProperties - healthCheckProperties foreach hikariConfig.setHealthCheckProperties - healthCheckRegistry foreach hikariConfig.setHealthCheckRegistry - metricRegistry foreach hikariConfig.setMetricRegistry - metricsTrackerFactory foreach hikariConfig.setMetricsTrackerFactory - scheduledExecutor foreach hikariConfig.setScheduledExecutor - threadFactory foreach hikariConfig.setThreadFactory - - hikariConfig - -object HikariConfigBuilder: - - /** - * Methods for retrieving data from the ldbc default specified path. - * - * {{{ - * ldbc.hikari { - * jdbc_url = ... - * username = ... - * password = ... - * } - * }}} - */ - def default: HikariConfigBuilder = new HikariConfigBuilder {} - - /** - * Methods for retrieving data from a user-specified conf path. - * - * @param confPath - * Path of conf from which user-specified data is to be retrieved - * - * {{{ - * {user path} { - * jdbc_url = ... - * username = ... - * password = ... - * } - * }}} - */ - def from(confPath: String): HikariConfigBuilder = new HikariConfigBuilder: - override protected val path: String = confPath diff --git a/module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariDataSourceBuilder.scala b/module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariDataSourceBuilder.scala deleted file mode 100644 index b4c94770f..000000000 --- a/module/ldbc-hikari/src/main/scala/ldbc/hikari/HikariDataSourceBuilder.scala +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). - * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.hikari - -import com.zaxxer.hikari.{ HikariConfig, HikariDataSource } - -import cats.effect.* -import cats.effect.implicits.* - -/** - * A model for building a database. HikariCP construction, thread pool generation for database connection, test - * connection, etc. are performed via the method. - * - * @tparam F - * the effect type. 
- */ -trait HikariDataSourceBuilder[F[_]: Sync] extends HikariConfigBuilder: - - /** - * Method for generating HikariDataSource with Resource. - * - * @param factory - * Process to generate HikariDataSource - */ - private def createDataSourceResource(factory: => HikariDataSource): Resource[F, HikariDataSource] = - Resource.fromAutoCloseable(Sync[F].delay(factory)) - - /** - * Method to generate Config for HikariCP. - */ - private def buildConfig(): Resource[F, HikariConfig] = - Sync[F].delay { - val hikariConfig = build() - hikariConfig.validate() - hikariConfig - }.toResource - - /** - * Method to generate DataSource from HikariCPConfig generation. - */ - def buildDataSource(): Resource[F, HikariDataSource] = - for - hikariConfig <- buildConfig() - hikariDataSource <- createDataSourceResource(new HikariDataSource(hikariConfig)) - yield hikariDataSource - - /** - * Methods for generating DataSource from user-generated HikariCPConfig. - * - * @param hikariConfig - * User-generated HikariCP Config file - */ - def buildFromConfig(hikariConfig: HikariConfig): Resource[F, HikariDataSource] = - createDataSourceResource(new HikariDataSource(hikariConfig)) - -object HikariDataSourceBuilder: - - def default[F[_]: Sync]: HikariDataSourceBuilder[F] = new HikariDataSourceBuilder[F] {} diff --git a/module/ldbc-hikari/src/test/resources/reference.conf b/module/ldbc-hikari/src/test/resources/reference.conf deleted file mode 100644 index dcda815d6..000000000 --- a/module/ldbc-hikari/src/test/resources/reference.conf +++ /dev/null @@ -1,53 +0,0 @@ -ldbc.hikari { - catalog = "ldbc" - connection_timeout = 30 s - idle_timeout = 10 m - leak_detection_threshold = 0 - maximum_pool_size = 32 - max_lifetime = 30 m - minimum_idle = 10 - pool_name = "ldbc-pool" - validation_timeout = 5 s - allow_pool_suspension = false - auto_commit = true - connection_init_sql = "select 1" - connection_test_query = "select 1" - data_source_classname = "com.mysql.cj.jdbc.Driver" - datasource_jndi = "" - 
initialization_fail_timeout = 1 ms - isolate_internal_queries = false - jdbc_url = "jdbc:mysql://127.0.0.1:3306/ldbc" - readonly = false - register_mbeans = false - schema = "ldbc" - username = "ldbc" - password = "mysql" - transaction_isolation = "TRANSACTION_NONE" -} - -ldbc.hikari.failure { - catalog = "ldbc" - connection_timeout = 30 s - idle_timeout = 10 m - leak_detection_threshold = 0 - maximum_pool_size = 32 - max_lifetime = 30 m - minimum_idle = 10 - pool_name = "ldbc-pool" - validation_timeout = 5 s - allow_pool_suspension = false - auto_commit = true - connection_init_sql = "select 1" - connection_test_query = "select 1" - data_source_classname = "com.mysql.cj.jdbc.Driver" - datasource_jndi = "" - initialization_fail_timeout = 1 ms - isolate_internal_queries = false - jdbc_url = "jdbc:mysql://127.0.0.1:3306/ldbc" - readonly = false - register_mbeans = false - schema = "ldbc" - username = "ldbc" - password = "mysql" - transaction_isolation = "test" -} diff --git a/module/ldbc-hikari/src/test/scala/ldbc/hikari/HikariConfigBuilderTest.scala b/module/ldbc-hikari/src/test/scala/ldbc/hikari/HikariConfigBuilderTest.scala deleted file mode 100644 index 0efebb3db..000000000 --- a/module/ldbc-hikari/src/test/scala/ldbc/hikari/HikariConfigBuilderTest.scala +++ /dev/null @@ -1,109 +0,0 @@ -/** - * Copyright (c) 2023-2025 by Takahiko Tominaga - * This software is licensed under the MIT License (MIT). 
- * For more information see LICENSE or https://opensource.org/licenses/MIT - */ - -package ldbc.hikari - -import java.util.concurrent.TimeUnit - -import scala.concurrent.duration.Duration - -import org.specs2.mutable.Specification - -object HikariConfigBuilderTest extends Specification: - - private val hikariConfig = HikariConfigBuilder.default.build() - - "Testing the HikariConfigBuilder" should { - - "The value of the specified key can be retrieved from the conf file" in { - hikariConfig.getCatalog == "ldbc" - } - - "The value of the specified key can be retrieved from the conf file" in { - hikariConfig.getJdbcUrl == "jdbc:mysql://127.0.0.1:3306/ldbc" - } - - "The connection_timeout setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getConnectionTimeout == Duration(30, TimeUnit.SECONDS).toMillis - } - - "The idle_timeout setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getIdleTimeout == Duration(10, TimeUnit.MINUTES).toMillis - } - - "The leak_detection_threshold setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getLeakDetectionThreshold == Duration.Zero.toMillis - } - - "The maximum_pool_size setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getMaximumPoolSize == 32 - } - - "The max_lifetime setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getMaxLifetime == Duration(30, TimeUnit.MINUTES).toMillis - } - - "The minimum_idle setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getMinimumIdle == 10 - } - - "The pool_name setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getPoolName == "ldbc-pool" - } - - "The validation_timeout setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getValidationTimeout == Duration(5, TimeUnit.SECONDS).toMillis - } - - "The connection_init_sql setting in HikariConfig matches the setting 
in the conf file" in { - hikariConfig.getConnectionInitSql == "select 1" - } - - "The connection_test_query setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getConnectionTestQuery == "select 1" - } - - "The connection_test_query setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getDataSourceClassName == "com.mysql.cj.jdbc.Driver" - } - - "The datasource_jndi setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getDataSourceJNDI == "" - } - - "The initialization_fail_timeout setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getInitializationFailTimeout == Duration(1, TimeUnit.MILLISECONDS).toMillis - } - - "The jdbc_url setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getJdbcUrl == "jdbc:mysql://127.0.0.1:3306/ldbc" - } - - "The schema setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getSchema == "ldbc" - } - - "The username setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getUsername == "ldbc" - } - - "The password setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getPassword == "mysql" - } - - "The transaction_isolation setting in HikariConfig matches the setting in the conf file" in { - hikariConfig.getTransactionIsolation == "TRANSACTION_NONE" - } - } - -object HikariConfigBuilderFailureTest extends Specification, HikariConfigBuilder: - override protected val path: String = "ldbc.hikari.failure" - - "Testing the HikariConfigBuilderFailureTest" should { - "IllegalArgumentException exception is raised when transaction_isolation is set to a value other than expected" in { - getTransactionIsolation must throwAn[IllegalArgumentException] - } - } diff --git a/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala b/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala new file mode 
100644 index 000000000..5aec7da8a --- /dev/null +++ b/module/ldbc-zio-interop/src/main/scala/ldbc/zio/interop/package.scala @@ -0,0 +1,25 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). + * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.zio + +import java.util.UUID + +import cats.effect.std.UUIDGen +import cats.effect.Async + +import fs2.hashing.Hashing +import fs2.io.net.Network + +import zio.* + +package object interop: + implicit def asyncToZIO: Async[Task] = zio.interop.catz.asyncInstance + implicit def consoleToZIO: cats.effect.std.Console[Task] = cats.effect.std.Console.make[Task] + implicit def uuidGenToZIO: UUIDGen[Task] = new UUIDGen[Task]: + override def randomUUID: Task[UUID] = ZIO.attempt(UUID.randomUUID()) + implicit def hashingToZIO: Hashing[Task] = Hashing.forSync[Task] + implicit def networkToZIO: Network[Task] = Network.forAsync[Task] diff --git a/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/QueryTest.scala b/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/QueryTest.scala new file mode 100644 index 000000000..14a5f4c64 --- /dev/null +++ b/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/QueryTest.scala @@ -0,0 +1,522 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.zio.interop + +import java.time.* + +import cats.Monad + +import ldbc.connector.* +import ldbc.connector.data.* + +import zio.* +import zio.test.* +import zio.test.Assertion.* + +object QueryTest extends ZIOSpecDefault: + + private val datasource = + MySQLDataSource + .build[Task]("127.0.0.1", 13306, "ldbc") + .setPassword("password") + .setDatabase("connector_test") + .setSSL(SSL.Trusted) + + def spec = suite("QueryTest")( + test("The client's PreparedStatement may use NULL as a parameter.") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `bit`, `bit_null` FROM `all_types` WHERE `bit_null` is ?") + resultSet <- statement.setNull(1, MysqlType.BIT.jdbcType) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Byte, Byte)](resultSet.next()) { + for + v1 <- resultSet.getByte(1) + v2 <- resultSet.getByte(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((1.toByte, 0.toByte)))) + } + }, + test("Client PreparedStatement should be able to retrieve BIT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `bit`, `bit_null` FROM `all_types` WHERE `bit` = ?") + resultSet <- statement.setByte(1, 1.toByte) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Byte, Byte)](resultSet.next()) { + for + v1 <- resultSet.getByte(1) + v2 <- resultSet.getByte(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((1.toByte, 0.toByte)))) + } + }, + test("Client PreparedStatement should be able to retrieve TINYINT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `tinyint`, `tinyint_null` FROM `all_types` WHERE `tinyint` = ?") + resultSet <- statement.setByte(1, 127.toByte) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Byte, Byte)](resultSet.next()) { + for + v1 <- 
resultSet.getByte(1) + v2 <- resultSet.getByte(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((127.toByte, 0.toByte)))) + } + }, + test("Client PreparedStatement should be able to retrieve unsigned TINYINT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement( + "SELECT `tinyint_unsigned`, `tinyint_unsigned_null` FROM `all_types` WHERE `tinyint_unsigned` = ?" + ) + resultSet <- statement.setShort(1, 255.toShort) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Short, Short)](resultSet.next()) { + for + v1 <- resultSet.getShort(1) + v2 <- resultSet.getShort(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((255.toShort, 0.toShort)))) + } + }, + test("Client PreparedStatement should be able to retrieve SMALLINT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `smallint`, `smallint_null` FROM `all_types` WHERE `smallint` = ?") + resultSet <- statement.setShort(1, 32767.toShort) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Short, Short)](resultSet.next()) { + for + v1 <- resultSet.getShort(1) + v2 <- resultSet.getShort(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((32767.toShort, 0.toShort)))) + } + }, + test("Client PreparedStatement should be able to retrieve unsigned SMALLINT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement( + "SELECT `smallint_unsigned`, `smallint_unsigned_null` FROM `all_types` WHERE `smallint_unsigned` = ?" 
+ ) + resultSet <- statement.setInt(1, 65535) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { + for + v1 <- resultSet.getInt(1) + v2 <- resultSet.getInt(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((65535, 0)))) + } + }, + test("Client PreparedStatement should be able to retrieve MEDIUMINT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `mediumint`, `mediumint_null` FROM `all_types` WHERE `mediumint` = ?") + resultSet <- statement.setInt(1, 8388607) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { + for + v1 <- resultSet.getInt(1) + v2 <- resultSet.getInt(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((8388607, 0)))) + } + }, + test("Client PreparedStatement should be able to retrieve INT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `int`, `int_null` FROM `all_types` WHERE `int` = ?") + resultSet <- statement.setInt(1, 2147483647) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { + for + v1 <- resultSet.getInt(1) + v2 <- resultSet.getInt(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((2147483647, 0)))) + } + }, + test("Client PreparedStatement should be able to retrieve unsigned INT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement( + "SELECT `int_unsigned`, `int_unsigned_null` FROM `all_types` WHERE `int_unsigned` = ?" 
+ ) + resultSet <- statement.setLong(1, 4294967295L) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Long, Long)](resultSet.next()) { + for + v1 <- resultSet.getLong(1) + v2 <- resultSet.getLong(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((4294967295L, 0L)))) + } + }, + test("Client PreparedStatement should be able to retrieve BIGINT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `bigint`, `bigint_null` FROM `all_types` WHERE `bigint` = ?") + resultSet <- statement.setLong(1, 9223372036854775807L) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Long, Long)](resultSet.next()) { + for + v1 <- resultSet.getLong(1) + v2 <- resultSet.getLong(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((9223372036854775807L, 0L)))) + } + }, + test("Client PreparedStatement should be able to retrieve unsigned BIGINT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement( + "SELECT `bigint_unsigned`, `bigint_unsigned_null` FROM `all_types` WHERE `bigint_unsigned` = ?" 
+ ) + resultSet <- statement.setString(1, "18446744073709551615") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("18446744073709551615", null)))) + } + }, + test("Client PreparedStatement should be able to retrieve FLOAT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `float`, `float_null` FROM `all_types` WHERE `float` > ?") + resultSet <- statement.setFloat(1, 3.4f) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Float, Float)](resultSet.next()) { + for + v1 <- resultSet.getFloat(1) + v2 <- resultSet.getFloat(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((3.40282e38f, 0f)))) + } + }, + test("Client PreparedStatement should be able to retrieve DOUBLE type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `double`, `double_null` FROM `all_types` WHERE `double` = ?") + resultSet <- statement.setDouble(1, 1.7976931348623157e308) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Double, Double)](resultSet.next()) { + for + v1 <- resultSet.getDouble(1) + v2 <- resultSet.getDouble(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((1.7976931348623157e308, 0.toDouble)))) + } + }, + test("Client PreparedStatement should be able to retrieve DECIMAL type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `decimal`, `decimal_null` FROM `all_types` WHERE `decimal` = ?") + resultSet <- statement.setBigDecimal(1, BigDecimal.decimal(9999999.99)) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (BigDecimal, BigDecimal)](resultSet.next()) { + for + v1 <- resultSet.getBigDecimal(1) + v2 <- resultSet.getBigDecimal(2) + yield (v1, v2) + } + yield 
assert(result)(equalTo(List((BigDecimal.decimal(9999999.99), null)))) + } + }, + test("Client PreparedStatement should be able to retrieve DATE type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `date`, `date_null` FROM `all_types` WHERE `date` = ?") + resultSet <- statement.setDate(1, LocalDate.of(2020, 1, 1)) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (LocalDate, LocalDate)](resultSet.next()) { + for + v1 <- resultSet.getDate(1) + v2 <- resultSet.getDate(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((LocalDate.of(2020, 1, 1), null)))) + } + }, + test("Client PreparedStatement should be able to retrieve TIME type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `time`, `time_null` FROM `all_types` WHERE `time` = ?") + resultSet <- statement.setTime(1, LocalTime.of(12, 34, 56)) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (LocalTime, LocalTime)](resultSet.next()) { + for + v1 <- resultSet.getTime(1) + v2 <- resultSet.getTime(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((LocalTime.of(12, 34, 56), null)))) + } + }, + test("Client PreparedStatement should be able to retrieve DATETIME type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `datetime`, `datetime_null` FROM `all_types` WHERE `datetime` = ?") + resultSet <- statement.setTimestamp(1, LocalDateTime.of(2020, 1, 1, 12, 34, 56)) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (LocalDateTime, LocalDateTime)](resultSet.next()) { + for + v1 <- resultSet.getTimestamp(1) + v2 <- resultSet.getTimestamp(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((LocalDateTime.of(2020, 1, 1, 12, 34, 56), null)))) + } + }, + test("Client PreparedStatement should be able to retrieve TIMESTAMP type records") { + datasource.getConnection.use { conn => + for + statement 
<- + conn.prepareStatement("SELECT `timestamp`, `timestamp_null` FROM `all_types` WHERE `timestamp` = ?") + resultSet <- statement.setTimestamp(1, LocalDateTime.of(2020, 1, 1, 12, 34, 56)) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (LocalDateTime, LocalDateTime)](resultSet.next()) { + for + v1 <- resultSet.getTimestamp(1) + v2 <- resultSet.getTimestamp(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((LocalDateTime.of(2020, 1, 1, 12, 34, 56), null)))) + } + }, + test("Client PreparedStatement should be able to retrieve YEAR type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `year`, `year_null` FROM `all_types` WHERE `year` = ?") + resultSet <- statement.setInt(1, 2020) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (Int, Int)](resultSet.next()) { + for + v1 <- resultSet.getInt(1) + v2 <- resultSet.getInt(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List((2020, 0)))) + } + }, + test("Client PreparedStatement should be able to retrieve CHAR type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `char`, `char_null` FROM `all_types` WHERE `char` = ?") + resultSet <- statement.setString(1, "char") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("char", null)))) + } + }, + test("Client PreparedStatement should be able to retrieve VARCHAR type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `varchar`, `varchar_null` FROM `all_types` WHERE `varchar` = ?") + resultSet <- statement.setString(1, "varchar") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- 
resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("varchar", null)))) + } + }, + test("Client PreparedStatement should be able to retrieve BINARY type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `binary`, `binary_null` FROM `all_types` WHERE `binary` = ?") + resultSet <- + statement.setBytes(1, Array[Byte](98, 105, 110, 97, 114, 121, 0, 0, 0, 0)) *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, Array[Byte])](resultSet.next()) { + for + v1 <- resultSet.getBytes(1) + v2 <- resultSet.getBytes(2) + yield (v1.mkString(":"), v2) + } + yield assert(result)(equalTo(List((Array[Byte](98, 105, 110, 97, 114, 121, 0, 0, 0, 0).mkString(":"), null)))) + } + }, + test("Client PreparedStatement should be able to retrieve VARBINARY type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `varbinary`, `varbinary_null` FROM `all_types` WHERE `varbinary` = ?") + resultSet <- statement.setString(1, "varbinary") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("varbinary", null)))) + } + }, + test("Client PreparedStatement should be able to retrieve TINYBLOB type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `tinyblob`, `tinyblob_null` FROM `all_types` WHERE `tinyblob` = ?") + resultSet <- statement.setString(1, "tinyblob") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("tinyblob", null)))) + } + }, + test("Client PreparedStatement should be able to retrieve BLOB type records") { + 
datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `blob`, `blob_null` FROM `all_types` WHERE `blob` = ?") + resultSet <- statement.setString(1, "blob") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("blob", null)))) + } + }, + test("Client PreparedStatement should be able to retrieve MEDIUMBLOB type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement( + "SELECT `mediumblob`, `mediumblob_null` FROM `all_types` WHERE `mediumblob` = ?" + ) + resultSet <- statement.setString(1, "mediumblob") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("mediumblob", null)))) + } + }, + test("Client PreparedStatement should be able to retrieve LONGBLOB type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `longblob`, `longblob_null` FROM `all_types` WHERE `longblob` = ?") + resultSet <- statement.setString(1, "longblob") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("longblob", null)))) + } + }, + test("Client PreparedStatement should be able to retrieve TINYTEXT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `tinytext`, `tinytext_null` FROM `all_types` WHERE `tinytext` = ?") + resultSet <- statement.setString(1, "tinytext") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- 
resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("tinytext", null)))) + } + }, + test("Client PreparedStatement should be able to retrieve TEXT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `text`, `text_null` FROM `all_types` WHERE `text` = ?") + resultSet <- statement.setString(1, "text") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("text", null)))) + } + }, + test("Client PreparedStatement should be able to retrieve MEDIUMTEXT type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement( + "SELECT `mediumtext`, `mediumtext_null` FROM `all_types` WHERE `mediumtext` = ?" + ) + resultSet <- statement.setString(1, "mediumtext") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("mediumtext", null)))) + } + }, + test("Client PreparedStatement should be able to retrieve LONGTEXT type records") { + datasource.getConnection.use { conn => + for + statement <- + conn.prepareStatement("SELECT `longtext`, `longtext_null` FROM `all_types` WHERE `longtext` = ?") + resultSet <- statement.setString(1, "longtext") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("longtext", null)))) + } + }, + test("Client PreparedStatement should be able to retrieve ENUM type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `enum`, `enum_null` 
FROM `all_types` WHERE `enum` = ?") + resultSet <- statement.setString(1, "a") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("a", null)))) + } + }, + test("Client PreparedStatement should be able to retrieve SET type records") { + datasource.getConnection.use { conn => + for + statement <- conn.prepareStatement("SELECT `set`, `set_null` FROM `all_types` WHERE `set` = ?") + resultSet <- statement.setString(1, "a,b") *> statement.executeQuery() + result <- Monad[Task].whileM[List, (String, String)](resultSet.next()) { + for + v1 <- resultSet.getString(1) + v2 <- resultSet.getString(2) + yield (v1, v2) + } + yield assert(result)(equalTo(List(("a,b", null)))) + } + } + ) diff --git a/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/UpdateTest.scala b/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/UpdateTest.scala new file mode 100644 index 000000000..3b571ea32 --- /dev/null +++ b/module/ldbc-zio-interop/src/test/scala/ldbc/zio/interop/UpdateTest.scala @@ -0,0 +1,268 @@ +/** + * Copyright (c) 2023-2025 by Takahiko Tominaga + * This software is licensed under the MIT License (MIT). 
+ * For more information see LICENSE or https://opensource.org/licenses/MIT + */ + +package ldbc.zio.interop + +import java.time.* + +import ldbc.connector.* +import ldbc.connector.data.* + +import zio.* +import zio.test.* +import zio.test.Assertion.* + +object UpdateTest extends ZIOSpecDefault: + + private val datasource = + MySQLDataSource + .build[Task]("127.0.0.1", 13306, "ldbc") + .setPassword("password") + .setDatabase("connector_test") + .setSSL(SSL.Trusted) + + def spec = suite("UpdateTest")( + test("Boolean values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_boolean_table`(`c1` BOOLEAN NOT NULL, `c2` BOOLEAN NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_boolean_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setBoolean(1, true) *> preparedStatement + .setNull(2, MysqlType.BOOLEAN.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_boolean_table`") + yield assert(count)(equalTo(1)) + } + }, + test("Byte values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate("CREATE TABLE `client_statement_byte_table`(`c1` BIT NOT NULL, `c2` BIT NULL)") + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_byte_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setByte(1, 1.toByte) *> preparedStatement + .setNull(2, MysqlType.BIT.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_byte_table`") + yield assert(count)(equalTo(1)) + } + }, + test("Short values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE 
`client_statement_short_table`(`c1` TINYINT NOT NULL, `c2` TINYINT NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_short_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setShort(1, 1.toShort) *> preparedStatement + .setNull(2, MysqlType.TINYINT.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_short_table`") + yield assert(count)(equalTo(1)) + } + }, + test("Int values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_int_table`(`c1` SMALLINT NOT NULL, `c2` SMALLINT NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_int_table`(`c1`, `c2`) VALUES (?, ?)") + count <- + preparedStatement.setInt(1, 1) *> preparedStatement.setNull( + 2, + MysqlType.SMALLINT.jdbcType + ) *> preparedStatement.executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_int_table`") + yield assert(count)(equalTo(1)) + } + }, + test("Long values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_long_table`(`c1` BIGINT NOT NULL, `c2` BIGINT NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_long_table`(`c1`, `c2`) VALUES (?, ?)") + count <- + preparedStatement.setLong(1, Long.MaxValue) *> preparedStatement + .setNull(2, MysqlType.BIGINT.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_long_table`") + yield assert(count)(equalTo(1)) + } + }, + test("BigInt values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- + statement.executeUpdate( + "CREATE TABLE `client_statement_bigint_table`(`c1` BIGINT 
unsigned NOT NULL, `c2` BIGINT unsigned NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_bigint_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setString(1, "18446744073709551615") *> preparedStatement + .setNull(2, MysqlType.BIGINT.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_bigint_table`") + yield assert(count)(equalTo(1)) + } + }, + test("Float values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- + statement.executeUpdate("CREATE TABLE `client_statement_float_table`(`c1` FLOAT NOT NULL, `c2` FLOAT NULL)") + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_float_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setFloat(1, 1.1f) *> preparedStatement + .setNull(2, MysqlType.FLOAT.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_float_table`") + yield assert(count)(equalTo(1)) + } + }, + test("Double values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_double_table`(`c1` DOUBLE NOT NULL, `c2` DOUBLE NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_double_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setDouble(1, 1.1) *> preparedStatement + .setNull(2, MysqlType.DOUBLE.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_double_table`") + yield assert(count)(equalTo(1)) + } + }, + test("BigDecimal values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_bigdecimal_table`(`c1` DECIMAL NOT NULL, `c2` 
DECIMAL NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_bigdecimal_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setBigDecimal(1, BigDecimal.decimal(1.1)) *> preparedStatement + .setNull(2, MysqlType.DECIMAL.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_bigdecimal_table`") + yield assert(count)(equalTo(1)) + } + }, + test("String values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_string_table`(`c1` VARCHAR(255) NOT NULL, `c2` VARCHAR(255) NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_string_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setString(1, "test") *> preparedStatement + .setNull(2, MysqlType.VARCHAR.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_string_table`") + yield assert(count)(equalTo(1)) + } + }, + test("Array[Byte] values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_bytes_table`(`c1` BINARY(10) NOT NULL, `c2` BINARY NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_bytes_table`(`c1`, `c2`) VALUES (?, ?)") + count <- + preparedStatement.setBytes(1, Array[Byte](98, 105, 110, 97, 114, 121)) *> preparedStatement + .setNull(2, MysqlType.BINARY.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_bytes_table`") + yield assert(count)(equalTo(1)) + } + }, + test("java.time.LocalTime values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate("CREATE TABLE 
`client_statement_time_table`(`c1` TIME NOT NULL, `c2` TIME NULL)") + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_time_table`(`c1`, `c2`) VALUES (?, ?)") + count <- + preparedStatement.setTime(1, LocalTime.of(12, 34, 56)) *> preparedStatement.setNull( + 2, + MysqlType.TIME.jdbcType + ) *> preparedStatement.executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_time_table`") + yield assert(count)(equalTo(1)) + } + }, + test("java.time.LocalDate values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate("CREATE TABLE `client_statement_date_table`(`c1` DATE NOT NULL, `c2` DATE NULL)") + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_date_table`(`c1`, `c2`) VALUES (?, ?)") + count <- + preparedStatement.setDate(1, LocalDate.of(2020, 1, 1)) *> preparedStatement.setNull( + 2, + MysqlType.DATE.jdbcType + ) *> preparedStatement.executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_date_table`") + yield assert(count)(equalTo(1)) + } + }, + test("java.time.LocalDateTime values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- conn.createStatement() + _ <- statement.executeUpdate( + "CREATE TABLE `client_statement_datetime_table`(`c1` TIMESTAMP NOT NULL, `c2` TIMESTAMP NULL)" + ) + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_datetime_table`(`c1`, `c2`) VALUES (?, ?)") + count <- preparedStatement.setTimestamp(1, LocalDateTime.of(2020, 1, 1, 12, 34, 56)) *> preparedStatement + .setNull(2, MysqlType.TIMESTAMP.jdbcType) *> preparedStatement.executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_datetime_table`") + yield assert(count)(equalTo(1)) + } + }, + test("java.time.Year values can be set as parameters") { + datasource.getConnection.use { conn => + for + statement <- 
conn.createStatement() + _ <- statement.executeUpdate("CREATE TABLE `client_statement_year_table`(`c1` YEAR NOT NULL, `c2` YEAR NULL)") + preparedStatement <- + conn.prepareStatement("INSERT INTO `client_statement_year_table`(`c1`, `c2`) VALUES (?, ?)") + count <- + preparedStatement.setInt(1, 2020) *> preparedStatement + .setNull(2, MysqlType.YEAR.jdbcType) *> preparedStatement + .executeUpdate() + _ <- statement.executeUpdate("DROP TABLE `client_statement_year_table`") + yield assert(count)(equalTo(1)) + } + } + ) diff --git a/plugin/src/main/scala/ldbc/sbt/Dependencies.scala b/plugin/src/main/scala/ldbc/sbt/Dependencies.scala index 87548c3d6..f34fcd41e 100644 --- a/plugin/src/main/scala/ldbc/sbt/Dependencies.scala +++ b/plugin/src/main/scala/ldbc/sbt/Dependencies.scala @@ -17,7 +17,6 @@ trait Dependencies { val ldbcQueryBuilder = component("ldbc-query-builder") val ldbcSchema = component("ldbc-schema") val ldbcCodegen = component("ldbc-codegen") - val ldbcHikari = component("ldbc-hikari") val jdbcConnector = component("jdbc-connector") val ldbcConnector = component("ldbc-connector") } diff --git a/project/BuildSettings.scala b/project/BuildSettings.scala index 0b498e202..9141e4955 100644 --- a/project/BuildSettings.scala +++ b/project/BuildSettings.scala @@ -65,17 +65,6 @@ object BuildSettings { ) ) - /** A project that runs in the sbt runtime. */ - object LepusSbtProject { - def apply(name: String, dir: String): Project = - Project(name, file(dir)) - .settings(scalaVersion := scala3) - .settings(scalacOptions ++= additionalSettings) - .settings(scalacOptions --= removeSettings) - .settings(commonSettings) - .enablePlugins(AutomateHeaderPlugin) - } - /** A project that is an sbt plugin. 
*/ object LepusSbtPluginProject { def apply(name: String, dir: String): Project = diff --git a/project/Dependencies.scala b/project/Dependencies.scala deleted file mode 100644 index 012eb8779..000000000 --- a/project/Dependencies.scala +++ /dev/null @@ -1,33 +0,0 @@ -/** This file is part of the ldbc. For the full copyright and license information, please view the LICENSE file that was - * distributed with this source code. - */ - -import sbt.* - -import ScalaVersions.* - -object Dependencies { - - val catsEffect = "org.typelevel" %% "cats-effect" % "3.6.3" - - val parserCombinators = "org.scala-lang.modules" %% "scala-parser-combinators" % "2.3.0" - - val mysqlVersion = "8.4.0" - val mysql = "com.mysql" % "mysql-connector-j" % mysqlVersion - - val typesafeConfig = "com.typesafe" % "config" % "1.4.5" - - val hikariCP = "com.zaxxer" % "HikariCP" % "7.0.2" - - val scala3Compiler = "org.scala-lang" %% "scala3-compiler" % scala3 - - val doobie = "org.tpolecat" %% "doobie-core" % "1.0.0-RC11" - - val slick = "com.typesafe.slick" %% "slick" % "3.6.1" - - val specs2Version = "5.6.4" - val specs2: Seq[ModuleID] = Seq( - "specs2-core", - "specs2-junit" - ).map("org.specs2" %% _ % specs2Version % Test) -} diff --git a/project/LaikaSettings.scala b/project/LaikaSettings.scala index 905808007..57c8bf44d 100644 --- a/project/LaikaSettings.scala +++ b/project/LaikaSettings.scala @@ -31,9 +31,10 @@ object LaikaSettings { val v02: Version = version(LdbcVersions.v02) val v03: Version = version(LdbcVersions.v03) - val v04: Version = version(LdbcVersions.latest, "Stable") - val current: Version = v04 - val all: Seq[Version] = Seq(v04, v03, v02) + val v04: Version = version(LdbcVersions.v04, "Stable") + val v05: Version = version(LdbcVersions.latest, "Dev") + val current: Version = v05 + val all: Seq[Version] = Seq(v05, v04, v03, v02) val config: Versions = Versions .forCurrentVersion(current) diff --git a/project/Versions.scala b/project/Versions.scala index 5bd15abd5..e9fd48d2a 
100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -3,7 +3,8 @@ */ object LdbcVersions { - val latest = "0.4" + val latest = "0.5" + val v04 = "0.4" val v03 = "0.3" val v02 = "0.2" val v01 = "0.1" diff --git a/tests/shared/src/test/scala/ldbc/tests/ConnectionTest.scala b/tests/shared/src/test/scala/ldbc/tests/ConnectionTest.scala index c5e9369cd..1c634b748 100644 --- a/tests/shared/src/test/scala/ldbc/tests/ConnectionTest.scala +++ b/tests/shared/src/test/scala/ldbc/tests/ConnectionTest.scala @@ -126,7 +126,7 @@ trait ConnectionTest extends CatsEffectSuite: assertIO( datasource().getConnection.use(_.getMetaData().map(_.getDriverVersion())), if prefix == "jdbc" then "mysql-connector-j-8.4.0 (Revision: 1c3f5c149e0bfe31c7fbeb24e2d260cd890972c4)" - else "ldbc-connector-0.4.1" + else "ldbc-connector-0.5.0" ) } diff --git a/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala b/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala index 9d57a9ba6..f327c244a 100644 --- a/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala +++ b/tests/shared/src/test/scala/ldbc/tests/DBIOTest.scala @@ -15,6 +15,7 @@ import munit.CatsEffectSuite import ldbc.dsl.* import ldbc.connector.* +import ldbc.connector.exception.* import ldbc.Connector @@ -102,3 +103,40 @@ class DBIOTest extends CatsEffectSuite: program.attempt.readOnly(connector).map(_.isLeft) ) } + + test("DBIO#updateRaw") { + assertIO( + (for + r1 <- DBIO.updateRaw("CREATE DATABASE `dbio`;").commit(connector) + r2 <- DBIO.updateRaw("DROP DATABASE `dbio`;").commit(connector) + yield List(r1, r2)), + List(1, 0) + ) + } + + test("DBIO#updateRaw#Exception") { + interceptIO[SQLSyntaxErrorException]( + DBIO.updateRaw("CREATE `dbio`;").commit(connector) + ) + } + + test("DBIO#updateRaws") { + val sql = """ + |CREATE DATABASE `dbio`; + |DROP DATABASE `dbio`; + |""".stripMargin + assertIO( + DBIO.updateRaws(sql).commit(connector).map(_.toList), + List(1, 0) + ) + } + + test("DBIO#updateRaws#Exception") { + val sql = """ 
+ |CREATE DATABASE `dbio` + |DROP DATABASE `dbio` + |""".stripMargin + interceptIO[BatchUpdateException]( + DBIO.updateRaws(sql).commit(connector).map(_.toList) + ) + } diff --git a/tests/shared/src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala b/tests/shared/src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala index e9d630223..7618db39f 100644 --- a/tests/shared/src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala +++ b/tests/shared/src/test/scala/ldbc/tests/DatabaseMetaDataTest.scala @@ -160,7 +160,7 @@ trait DatabaseMetaDataTest extends CatsEffectSuite: yield metaData.getDriverVersion() }, if prefix == "jdbc" then "mysql-connector-j-8.4.0 (Revision: 1c3f5c149e0bfe31c7fbeb24e2d260cd890972c4)" - else "ldbc-connector-0.4.1" + else "ldbc-connector-0.5.0" ) } @@ -180,7 +180,7 @@ trait DatabaseMetaDataTest extends CatsEffectSuite: for metaData <- conn.getMetaData() yield metaData.getDriverMinorVersion() }, - if prefix == "jdbc" then 4 else 4 + if prefix == "jdbc" then 4 else 5 ) } @@ -2100,7 +2100,7 @@ trait DatabaseMetaDataTest extends CatsEffectSuite: for metaData <- conn.getMetaData() yield metaData.getJDBCMinorVersion() }, - if prefix == "jdbc" then 2 else 4 + if prefix == "jdbc" then 2 else 5 ) }