@@ -248,7 +248,8 @@ class ParquetFileFormat
pushDownStringPredicate,
pushDownInFilterThreshold,
isCaseSensitive,
datetimeRebaseSpec)
datetimeRebaseSpec,
variantExtractionSchema = Some(requiredSchema))
filters
// Collects all converted Parquet filter predicates. Notice that not all predicates
// can be converted (`ParquetFilters.createFilter` returns an `Option`). That's why
@@ -37,10 +37,14 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
import org.apache.parquet.schema.Type.Repetition

import org.apache.spark.sql.catalyst.expressions.variant.ObjectExtraction
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
import org.apache.spark.sql.execution.datasources.VariantMetadata
import org.apache.spark.sql.internal.LegacyBehaviorPolicy
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._

@@ -55,12 +59,17 @@ class ParquetFilters(
pushDownStringPredicate: Boolean,
pushDownInFilterThreshold: Int,
caseSensitive: Boolean,
datetimeRebaseSpec: RebaseSpec) {
datetimeRebaseSpec: RebaseSpec,
variantExtractionSchema: Option[StructType] = None) {
// A map from Parquet field name to data type, used when predicate pushdown applies.
//
// Each key in `nameToParquetField` represents a column; `dots` are used as separators for
// nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion.
// See `org.apache.spark.sql.connector.catalog.quote` for implementation details.
//
// If `variantExtractionSchema` is provided, extra entries are added for variant struct fields
// produced by PushVariantIntoScan: logical paths like "v.`0`" are mapped to the corresponding
// physical typed_value primitive to enable row-group skipping on shredded columns.
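//
// E.g. a top-level variant column `v` with a single extraction `$.a` of type bigint
// gains the extra entry "v.`0`" -> the primitive at v/typed_value/a/typed_value.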
private val nameToParquetField : Map[String, ParquetPrimitiveField] = {
def getNormalizedLogicalType(p: PrimitiveType): LogicalTypeAnnotation = {
// SPARK-40280: Signed 64 bits on an INT64 and signed 32 bits on an INT32 are optional, but
@@ -98,18 +107,99 @@ class ParquetFilters(
}
}

// For a shredded variant physical group and a sequence of object-extraction keys from a
// VariantMetadata path, return the leaf typed_value primitive as a ParquetPrimitiveField.
// The shredding layout is regular: the physical path is always
// physPrefix / variantCol / typed_value / k0 / typed_value / ... / kN / typed_value
// schema.containsPath / schema.getType navigate this path in one shot.
// Returns None if any segment is absent or the leaf is not a non-REPEATED primitive.
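// For instance, with physPrefix = Array("s"), a variant group named "v", and
// keys = Array("a", "b"), physPath becomes
// Array("s", "v", "typed_value", "a", "typed_value", "b", "typed_value").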
def resolveShreddedPrimitive(
variantGroup: GroupType,
keys: Array[String],
physPrefix: Array[String]): Option[ParquetPrimitiveField] = {
if (keys.isEmpty) return None
val TYPED = "typed_value"
val physPath = physPrefix ++ Array(variantGroup.getName, TYPED) ++
keys.flatMap(k => Array(k, TYPED))
if (!schema.containsPath(physPath)) return None
schema.getType(physPath: _*) match {
case p: PrimitiveType if p.getRepetition != Repetition.REPEATED =>
Some(ParquetPrimitiveField(physPath,
ParquetSchemaType(
getNormalizedLogicalType(p), p.getPrimitiveTypeName, p.getTypeLength)))
case _ => None
}
}

// Walk the variant fields alongside the physical Parquet group, collecting extra
// name -> ParquetPrimitiveField entries for variant struct fields so that pushed
// filters like GreaterThan("v.`0`", 1000L) can be mapped to the physical shredded column.
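// E.g. given the pushed-down schema s: struct&lt;v: struct&lt;0: bigint [VariantMetadata($.a)]&gt;&gt;,
// this recurses through the ordinary struct `s` and maps its variant struct child to
// "s.v.`0`" -> the primitive at s/v/typed_value/a/typed_value.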
def variantStructEntries(
variantFields: Seq[StructField],
physGroup: GroupType,
parentNames: Array[String]): Seq[(String, ParquetPrimitiveField)] = {
variantFields.flatMap { field =>
val physFieldOpt = physGroup.getFields.asScala.collectFirst {
case g: GroupType if
(if (caseSensitive) g.getName == field.name
else g.getName.equalsIgnoreCase(field.name)) => g
}
physFieldOpt match {
case None => Nil
case Some(physChild) =>
field.dataType match {
// Variant struct: each child is a requested variant extraction with VariantMetadata.
case s: StructType if VariantMetadata.isVariantStruct(s) =>
s.fields.flatMap { extraction =>
if (!extraction.metadata.contains(VariantMetadata.METADATA_KEY)) Nil
else {
val meta = VariantMetadata.fromMetadata(extraction.metadata)
val segments = try { meta.parsedPath() } catch { case _: Exception => null }
if (segments == null) Nil
else {
// Only scalar object-extraction paths are eligible for statistics pushdown.
val keys = segments.collect { case o: ObjectExtraction => o.key }
if (keys.length != segments.length) Nil
else resolveShreddedPrimitive(physChild, keys, parentNames) match {
case None => Nil
case Some(primField) =>
val logPath = (parentNames :+ field.name :+ extraction.name)
.toImmutableArraySeq.quoted
Seq(logPath -> primField)
}
}
}
}
// Ordinary struct: recurse into nested fields.
case s: StructType if !VariantMetadata.isVariantStruct(s) =>
variantStructEntries(s.fields.toSeq, physChild, parentNames :+ field.name)
case _ => Nil
}
}
}
}

val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field =>
(field.fieldNames.toImmutableArraySeq.quoted, field)
}

val extraFields = variantExtractionSchema match {
case Some(variantSchema) =>
variantStructEntries(variantSchema.fields.toSeq, schema.asGroupType(), Array.empty)
case None => Nil
}

val allFields = primitiveFields ++ extraFields
if (caseSensitive) {
primitiveFields.toMap
allFields.toMap
} else {
// Don't consider ambiguity here, i.e. more than one field is matched in case insensitive
// mode, just skip pushdown for these fields, they will trigger Exception when reading,
// See: SPARK-25132.
val dedupPrimitiveFields =
primitiveFields
allFields
.groupBy(_._1.toLowerCase(Locale.ROOT))
.filter(_._2.size == 1)
.transform((_, v) => v.head._2)
@@ -34,7 +34,7 @@ import org.apache.parquet.filter2.predicate.FilterApi._
import org.apache.parquet.filter2.predicate.Operators.{Column => _, Eq, Gt, GtEq, In => FilterIn, Lt, LtEq, NotEq, UserDefinedByInstance}
import org.apache.parquet.hadoop.{ParquetFileReader, ParquetInputFormat, ParquetOutputFormat}
import org.apache.parquet.hadoop.util.HadoopInputFile
import org.apache.parquet.schema.MessageType
import org.apache.parquet.schema.{MessageType, MessageTypeParser}

import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException}
import org.apache.spark.sql._
@@ -46,7 +46,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath
import org.apache.spark.sql.execution.ExplainMode
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelationWithTable, PushableColumnAndNestedColumn}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelationWithTable, PushableColumnAndNestedColumn, VariantMetadata}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.functions._
@@ -2358,6 +2358,229 @@ class ParquetV1FilterSuite extends ParquetFilterSuite {
}
}
}

/**
* Tests for shredded variant filter pushdown in ParquetFilters.
* These tests construct ParquetFilters directly and exercise the shredded-column
* mapping (resolveShreddedPrimitive / variantStructEntries) through createFilter.
*/
private def variantParquetFilters(
parquetSchemaStr: String,
variantExtractionSchema: Option[StructType],
caseSensitive: Boolean = true): ParquetFilters = {
val parquetSchema = MessageTypeParser.parseMessageType(parquetSchemaStr)
new ParquetFilters(
parquetSchema,
conf.parquetFilterPushDownDate,
conf.parquetFilterPushDownTimestamp,
conf.parquetFilterPushDownDecimal,
conf.parquetFilterPushDownStringPredicate,
conf.parquetFilterPushDownInFilterThreshold,
caseSensitive,
RebaseSpec(LegacyBehaviorPolicy.CORRECTED),
variantExtractionSchema = variantExtractionSchema)
}

/** Construct a variant struct field with VariantMetadata. */
private def variantStructField(name: String, dt: DataType, path: String): StructField =
StructField(name, dt, metadata = VariantMetadata(path, failOnError = true, "UTC").toMetadata)

/**
* Build the variantExtractionSchema that PushVariantIntoScan produces for a top-level
* variant column `colName` with one extracted field at `variantPath` of type `dt`.
* The struct child is named by its ordinal "0".
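* E.g. variantExtractionSchema("v", "$.a", LongType) builds
* v: struct&lt;0: bigint [VariantMetadata($.a)]&gt;.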
*/
private def variantExtractionSchema(
colName: String,
variantPath: String,
dt: DataType): StructType = {
StructType(Seq(
StructField(colName, StructType(Seq(variantStructField("0", dt, variantPath))))))
}

test("variant shredded filter: top-level bigint field resolves to physical typed_value path") {
val parquetSchema =
"""message spark_schema {
| optional group v {
| optional binary value;
| optional group typed_value {
| optional group a {
| optional binary value;
| optional int64 typed_value;
| }
| }
| }
|}""".stripMargin
val pushDownSchema = variantExtractionSchema("v", "$.a", LongType)
val pf = variantParquetFilters(parquetSchema, Some(pushDownSchema))
// "v.`0`" should resolve to physical column "v.typed_value.a.typed_value"
val filter = pf.createFilter(sources.GreaterThan("v.`0`", 1000L))
assert(filter.isDefined,
"Expected filter predicate to be created via shredded column mapping v.`0`")
assert(filter.get.toString.contains("v.typed_value.a.typed_value"),
s"Expected physical path v.typed_value.a.typed_value in filter, got: ${filter.get}")
}

test("variant shredded filter: without variantExtractionSchema logical path is unknown") {
val parquetSchema =
"""message spark_schema {
| optional group v {
| optional binary value;
| optional group typed_value {
| optional group a {
| optional binary value;
| optional int64 typed_value;
| }
| }
| }
|}""".stripMargin
val pf = variantParquetFilters(parquetSchema, variantExtractionSchema = None)
// Without variantExtractionSchema the mapping for "v.`0`" is unknown, so createFilter returns None.
assert(pf.createFilter(sources.GreaterThan("v.`0`", 1000L)).isEmpty)
}

test("variant shredded filter: top-level string field resolves correctly") {
val parquetSchema =
"""message spark_schema {
| optional group v {
| optional binary value;
| optional group typed_value {
| optional group b {
| optional binary value;
| optional binary typed_value (STRING);
| }
| }
| }
|}""".stripMargin
val pushDownSchema = variantExtractionSchema("v", "$.b", StringType)
val pf = variantParquetFilters(parquetSchema, Some(pushDownSchema))
// GreaterThan with a string value should produce a predicate on the physical column
val filter = pf.createFilter(sources.GreaterThan("v.`0`", "str500"))
assert(filter.isDefined,
"Expected GreaterThan on shredded string column to produce a predicate")
}

test("variant shredded filter: nested variant (s.v) uses full physical prefix") {
// Bug fixed by this PR: without physPrefix the path was "v.typed_value.a.typed_value"
// instead of the correct "s.v.typed_value.a.typed_value".
val parquetSchema =
"""message spark_schema {
| optional group s {
| optional group v {
| optional binary value;
| optional group typed_value {
| optional group a {
| optional binary value;
| optional int64 typed_value;
| }
| }
| }
| }
|}""".stripMargin
// variantExtractionSchema: s: struct<v: struct<0: bigint [VariantMetadata($.a)]>>
val pushDownSchema = StructType(Seq(
StructField("s", StructType(Seq(
StructField("v", StructType(Seq(
variantStructField("0", LongType, "$.a")))))))))
val pf = variantParquetFilters(parquetSchema, Some(pushDownSchema))
val filter = pf.createFilter(sources.GreaterThan("s.v.`0`", 500L))
assert(filter.isDefined,
"Expected filter predicate for nested s.v.`0` via shredded column mapping")
assert(filter.get.toString.contains("s.v.typed_value.a.typed_value"),
s"Expected full physical path s.v.typed_value.a.typed_value, got: ${filter.get}")
}

test("variant shredded filter: multi-level path $.a.b resolves correctly") {
val parquetSchema =
"""message spark_schema {
| optional group v {
| optional binary value;
| optional group typed_value {
| optional group a {
| optional binary value;
| optional group typed_value {
| optional group b {
| optional binary value;
| optional int64 typed_value;
| }
| }
| }
| }
| }
|}""".stripMargin
val pushDownSchema = variantExtractionSchema("v", "$.a.b", LongType)
val pf = variantParquetFilters(parquetSchema, Some(pushDownSchema))
val filter = pf.createFilter(sources.LessThan("v.`0`", 100L))
assert(filter.isDefined,
"Expected filter predicate for multi-level path $.a.b via shredded column mapping")
assert(filter.get.toString.contains("v.typed_value.a.typed_value.b.typed_value"),
s"Expected multi-level physical path in filter, got: ${filter.get}")
}

test("variant shredded filter: field absent from physical schema falls back gracefully") {
val parquetSchema =
"""message spark_schema {
| optional group v {
| optional binary value;
| optional group typed_value {
| optional group a {
| optional binary value;
| optional int64 typed_value;
| }
| }
| }
|}""".stripMargin
val pushDownSchema = variantExtractionSchema("v", "$.missing", LongType)
val pf = variantParquetFilters(parquetSchema, Some(pushDownSchema))
// resolveShreddedPrimitive returns None -- no FilterPredicate is created (falls back to record-level evaluation)
assert(pf.createFilter(sources.GreaterThan("v.`0`", 1000L)).isEmpty)
}

test("variant shredded filter: array index path is skipped (falls back to record-level)") {
val parquetSchema =
"""message spark_schema {
| optional group v {
| optional binary value;
| optional group typed_value (LIST) {
| repeated group list {
| optional group element {
| optional binary value;
| optional int64 typed_value;
| }
| }
| }
| }
|}""".stripMargin
// Path "$.arr[0]" contains an ArrayExtraction segment
val pushDownSchema = variantExtractionSchema("v", "$.arr[0]", LongType)
val pf = variantParquetFilters(parquetSchema, Some(pushDownSchema))
assert(pf.createFilter(sources.GreaterThan("v.`0`", 1000L)).isEmpty)
}

test("variant shredded filter: case-insensitive column matching") {
val parquetSchema =
"""message spark_schema {
| optional group V {
| optional binary value;
| optional group typed_value {
| optional group a {
| optional binary value;
| optional int64 typed_value;
| }
| }
| }
|}""".stripMargin
// Spark schema uses lowercase "v" but Parquet column is "V"
val pushDownSchema = variantExtractionSchema("v", "$.a", LongType)
val pfSensitive =
variantParquetFilters(parquetSchema, Some(pushDownSchema), caseSensitive = true)
assert(pfSensitive.createFilter(sources.GreaterThan("v.`0`", 1000L)).isEmpty,
"Case-sensitive: 'v' should not match Parquet column 'V'")
val pfInsensitive =
variantParquetFilters(parquetSchema, Some(pushDownSchema), caseSensitive = false)
assert(pfInsensitive.createFilter(sources.GreaterThan("v.`0`", 1000L)).isDefined,
"Case-insensitive: 'v' should match Parquet column 'V'")
}
}

@ExtendedSQLTest