From dea23201e34082ba584f006aa8e5ba56cf32152a Mon Sep 17 00:00:00 2001 From: AnudeepKonaboina Date: Fri, 17 Oct 2025 10:37:16 +0530 Subject: [PATCH] Spark 3.5: StaticInvoke compatibility in ManualTypedEncoder (8/9-arg); deterministic + arg types; fallback to constructor. Add test. Signed-off-by: Anudeep Konaboina Signed-off-by: AnudeepKonaboina --- .gitignore | 6 ++ .../encoders/ManualTypedEncoder.scala | 85 ++++++++++++++++++- .../encoders/ManualTypedEncoderSpec.scala | 0 3 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/org/locationtech/rasterframes/encoders/ManualTypedEncoderSpec.scala diff --git a/.gitignore b/.gitignore index cc55b6fdb..9dc727c17 100644 --- a/.gitignore +++ b/.gitignore @@ -62,3 +62,9 @@ __pycache__ .coverage* *.jar .python-version + +# Ignore SBT lock files +project/.boot/**/sbt.boot.lock +project/.boot/**/sbt.components.lock +project/.ivy/.sbt.ivy.lock +project/.sbtboot/**/.sbt.cache.lock diff --git a/core/src/main/scala/org/locationtech/rasterframes/encoders/ManualTypedEncoder.scala b/core/src/main/scala/org/locationtech/rasterframes/encoders/ManualTypedEncoder.scala index d9fd8282b..cdfed7def 100644 --- a/core/src/main/scala/org/locationtech/rasterframes/encoders/ManualTypedEncoder.scala +++ b/core/src/main/scala/org/locationtech/rasterframes/encoders/ManualTypedEncoder.scala @@ -10,12 +10,95 @@ import scala.reflect.{ClassTag, classTag} /** Can be useful for non Scala types and for complicated case classes with implicits in the constructor. */ object ManualTypedEncoder { + /** Constructs StaticInvoke via reflection to handle Spark 3.4/3.5 constructor differences. */ + private def staticInvokeSafely( + targetClass: Class[_], + dataType: DataType, + functionName: String, + arguments: Seq[Expression], + propagateNull: Boolean, + returnNullable: Boolean + ): InvokeLike = { + val ctors = classOf[StaticInvoke].getConstructors + val boxedPropagateNull = Boolean.box(propagateNull) + val boxedReturnNullable = Boolean.box(returnNullable) + val TRUE = Boolean.box(true) + + val ctor = ctors.maxBy(_.getParameterTypes.length) + val argTypes: Seq[DataType] = arguments.map(_.dataType) + val targetModuleClass: Class[_] = { + val moduleName = targetClass.getName + "$" + try Class.forName(moduleName) + catch { case _: ClassNotFoundException => targetClass } + } + + def tryInvoke(onClass: Class[_]): InvokeLike = ctor.getParameterTypes.length match { + case 9 => + // (Class, DataType, String, Seq, Seq, boolean, boolean, boolean, Option) + ctor.newInstance( + onClass, + dataType, + functionName, + arguments, + argTypes, + boxedPropagateNull, + boxedReturnNullable, + TRUE, + None + ).asInstanceOf[InvokeLike] + case 8 => + // (Class, DataType, String, Seq, Seq, boolean, boolean, boolean) + ctor.newInstance( + onClass, + dataType, + functionName, + arguments, + argTypes, + boxedPropagateNull, + boxedReturnNullable, + TRUE + ).asInstanceOf[InvokeLike] + case _ => + throw new NotImplementedError("StaticInvoke constructor has unexpected shape") + } + ctor.getParameterTypes.length match { + case 9 | 8 => + // Try on the class first (top-level case classes have static forwarders), then on module + val firstError = try { + return tryInvoke(targetClass) + } catch { case t: Throwable => t } + tryInvoke(targetModuleClass) + case _ => + throw new NotImplementedError("StaticInvoke constructor has unexpected shape") + } + } + + /** Detect whether a static forwarder for `apply` of given arity exists on the given class. */ + private def hasStaticApply(onClass: Class[_], arity: Int): Boolean = { + import java.lang.reflect.Modifier + onClass.getMethods.exists { m => + m.getName == "apply" && m.getParameterCount == arity && Modifier.isStatic(m.getModifiers) + } + } + /** Invokes apply from the companion object. */ def staticInvoke[T: ClassTag]( fields: List[RecordEncoderField], fieldNameModify: String => String = identity, isNullable: Boolean = true - ): TypedEncoder[T] = apply[T](fields, { (classTag, newArgs, jvmRepr) => StaticInvoke(classTag.runtimeClass, jvmRepr, "apply", newArgs, propagateNull = true, returnNullable = false) }, fieldNameModify, isNullable) + ): TypedEncoder[T] = apply[T](fields, { (classTag, newArgs, jvmRepr) => + val target = classTag.runtimeClass + val moduleName = target.getName + "$" + val moduleClass = try Class.forName(moduleName) catch { case _: ClassNotFoundException => null } + val arity = newArgs.length + if ((hasStaticApply(target, arity)) || (moduleClass != null && hasStaticApply(moduleClass, arity))) { + staticInvokeSafely(target, jvmRepr, "apply", newArgs, propagateNull = true, returnNullable = false) + } + else { + // Fall back to directly invoking the primary constructor + NewInstance(target, newArgs, jvmRepr, propagateNull = true) + } + }, fieldNameModify, isNullable) /** Invokes object constructor. */ def newInstance[T: ClassTag]( diff --git a/core/src/test/scala/org/locationtech/rasterframes/encoders/ManualTypedEncoderSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/encoders/ManualTypedEncoderSpec.scala new file mode 100644 index 000000000..e69de29bb