diff --git a/build.sbt b/build.sbt
index 452b829..4b86f90 100644
--- a/build.sbt
+++ b/build.sbt
@@ -110,17 +110,23 @@ lazy val `sparksql-scalapb` = (projectMatrix in file("sparksql-scalapb"))
   .customRow(
     scalaVersions = Seq(Scala212, Scala213),
     axisValues = Seq(Spark32, ScalaPB0_10, VirtualAxis.jvm),
-    settings = Seq()
+    _.settings(
+      Test / testOptions += Tests.Filter(_ != "scalapb.spark.ExtensionsSpec")
+    )
   )
   .customRow(
     scalaVersions = Seq(Scala212),
     axisValues = Seq(Spark31, ScalaPB0_10, VirtualAxis.jvm),
-    settings = Seq()
+    _.settings(
+      Test / testOptions += Tests.Filter(_ != "scalapb.spark.ExtensionsSpec")
+    )
   )
   .customRow(
     scalaVersions = Seq(Scala212),
     axisValues = Seq(Spark30, ScalaPB0_10, VirtualAxis.jvm),
-    settings = Seq()
+    _.settings(
+      Test / testOptions += Tests.Filter(_ != "scalapb.spark.ExtensionsSpec")
+    )
   )
 
 ThisBuild / publishTo := sonatypePublishToBundle.value
diff --git a/sparksql-scalapb/src/test/protobuf/extensions.proto b/sparksql-scalapb/src/test/protobuf/extensions.proto
new file mode 100644
index 0000000..efb2510
--- /dev/null
+++ b/sparksql-scalapb/src/test/protobuf/extensions.proto
@@ -0,0 +1,17 @@
+syntax = "proto2";
+
+option java_package = "com.example.protos";
+
+message Foo {
+  optional string name = 1;
+
+  extensions 100 to 199;
+}
+
+message Baz {
+  extend Foo {
+    optional int32 bar = 126;
+  }
+
+  optional int32 id = 1;
+}
diff --git a/sparksql-scalapb/src/test/scala/ExtensionsSpec.scala b/sparksql-scalapb/src/test/scala/ExtensionsSpec.scala
new file mode 100644
index 0000000..95b2bca
--- /dev/null
+++ b/sparksql-scalapb/src/test/scala/ExtensionsSpec.scala
@@ -0,0 +1,31 @@
+package scalapb.spark
+
+import com.example.protos.extensions._
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.col
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+import scalapb.spark.Implicits._
+
+class ExtensionsSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll {
+  val spark: SparkSession = SparkSession
+    .builder()
+    .appName("ScalaPB Demo")
+    .master("local[2]")
+    .getOrCreate()
+
+  "Creating Dataset from message with nested extension" should "work" in {
+    val data = Seq(
+      Baz(id = Some(1)),
+      Baz(id = Some(2)),
+      Baz(id = Some(3))
+    )
+
+    val binaryDS = spark.createDataset(data.map(_.toByteArray))
+    binaryDS.show()
+
+    val protosDS = binaryDS.map(Baz.parseFrom(_))
+    protosDS.show()
+  }
+}