diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 850bfb4a89857..6a6c99e56b1e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -248,7 +248,8 @@ class ParquetFileFormat pushDownStringPredicate, pushDownInFilterThreshold, isCaseSensitive, - datetimeRebaseSpec) + datetimeRebaseSpec, + variantExtractionSchema = Some(requiredSchema)) filters // Collects all converted Parquet filter predicates. Notice that not all predicates // can be converted (`ParquetFilters.createFilter` returns an `Option`). That's why diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 4a9b17bf98e59..043240b1443a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -37,10 +37,14 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition +import org.apache.spark.sql.catalyst.expressions.variant.ObjectExtraction import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, RebaseSpec} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper +import org.apache.spark.sql.execution.datasources.VariantMetadata import 
org.apache.spark.sql.internal.LegacyBehaviorPolicy import org.apache.spark.sql.sources +import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -55,12 +59,17 @@ class ParquetFilters( pushDownStringPredicate: Boolean, pushDownInFilterThreshold: Int, caseSensitive: Boolean, - datetimeRebaseSpec: RebaseSpec) { + datetimeRebaseSpec: RebaseSpec, + variantExtractionSchema: Option[StructType] = None) { // A map which contains parquet field name and data type, if predicate push down applies. // // Each key in `nameToParquetField` represents a column; `dots` are used as separators for // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion. // See `org.apache.spark.sql.connector.catalog.quote` for implementation details. + // + // If `variantExtractionSchema` is provided, extra entries are added for variant struct fields + // produced by PushVariantIntoScan: logical paths like "v.`0`" are mapped to the corresponding + // physical typed_value primitive to enable row-group skipping on shredded columns. private val nameToParquetField : Map[String, ParquetPrimitiveField] = { def getNormalizedLogicalType(p: PrimitiveType): LogicalTypeAnnotation = { // SPARK-40280: Signed 64 bits on an INT64 and signed 32 bits on an INT32 are optional, but @@ -98,18 +107,99 @@ class ParquetFilters( } } + // For a shredded variant physical group and a sequence of object-extraction keys from a + // VariantMetadata path, return the leaf typed_value primitive as a ParquetPrimitiveField. + // The shredding layout is regular: the physical path is always + // physPrefix / variantCol / typed_value / k0 / typed_value / ... / kN / typed_value + // schema.containsPath / schema.getType navigate this path in one shot. + // Returns None if any segment is absent or the leaf is not a non-REPEATED primitive. 
+ def resolveShreddedPrimitive( + variantGroup: GroupType, + keys: Array[String], + physPrefix: Array[String]): Option[ParquetPrimitiveField] = { + if (keys.isEmpty) return None + val TYPED = "typed_value" + val physPath = physPrefix ++ Array(variantGroup.getName, TYPED) ++ + keys.flatMap(k => Array(k, TYPED)) + if (!schema.containsPath(physPath)) return None + schema.getType(physPath: _*) match { + case p: PrimitiveType if p.getRepetition != Repetition.REPEATED => + Some(ParquetPrimitiveField(physPath, + ParquetSchemaType( + getNormalizedLogicalType(p), p.getPrimitiveTypeName, p.getTypeLength))) + case _ => None + } + } + + // Loop through variant fields alongside the physical Parquet group, collecting extra + // name -> ParquetPrimitiveField entries for variant struct fields so that pushed + // filters like GreaterThan("v.0", 1000L) can be mapped to the physical shredded column. + def variantStructEntries( + variantFields: Seq[StructField], + physGroup: GroupType, + parentNames: Array[String]): Seq[(String, ParquetPrimitiveField)] = { + variantFields.flatMap { field => + val physFieldOpt = physGroup.getFields.asScala.collectFirst { + case g: GroupType if + (if (caseSensitive) g.getName == field.name + else g.getName.equalsIgnoreCase(field.name)) => g + } + physFieldOpt match { + case None => Nil + case Some(physChild) => + field.dataType match { + // Variant struct: each child is a requested variant extraction with VariantMetadata. + case s: StructType if VariantMetadata.isVariantStruct(s) => + s.fields.flatMap { extraction => + if (!extraction.metadata.contains(VariantMetadata.METADATA_KEY)) Nil + else { + val meta = VariantMetadata.fromMetadata(extraction.metadata) + val segments = try { meta.parsedPath() } catch { case _: Exception => null } + if (segments == null) Nil + else { + // Only scalar object-extraction paths are eligible for statistics pushdown. 
+ val keys = segments.collect { case o: ObjectExtraction => o.key } + if (keys.length != segments.length) Nil + else resolveShreddedPrimitive(physChild, keys, parentNames) match { + case None => Nil + case Some(primField) => + val logPath = (parentNames :+ field.name :+ extraction.name) + .toImmutableArraySeq.quoted + Seq(logPath -> primField) + } + } + } + } + // Ordinary struct: recurse into nested fields. + case s: StructType if !VariantMetadata.isVariantStruct(s) => + variantStructEntries(s.fields.toSeq, physChild, parentNames :+ field.name) + case _ => Nil + } + } + } + } + val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field => - import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper (field.fieldNames.toImmutableArraySeq.quoted, field) } + + val extraFields = variantExtractionSchema match { + case Some(variantSchema) => + val entries = variantStructEntries( + variantSchema.fields.toSeq, schema.asGroupType(), Array.empty) + entries + case None => Nil + } + + val allFields = primitiveFields ++ extraFields if (caseSensitive) { - primitiveFields.toMap + allFields.toMap } else { // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive // mode, just skip pushdown for these fields, they will trigger Exception when reading, // See: SPARK-25132. 
val dedupPrimitiveFields = - primitiveFields + allFields .groupBy(_._1.toLowerCase(Locale.ROOT)) .filter(_._2.size == 1) .transform((_, v) => v.head._2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 6b73cc8618d1b..d7794ee9455c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -34,7 +34,7 @@ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate.Operators.{Column => _, Eq, Gt, GtEq, In => FilterIn, Lt, LtEq, NotEq, UserDefinedByInstance} import org.apache.parquet.hadoop.{ParquetFileReader, ParquetInputFormat, ParquetOutputFormat} import org.apache.parquet.hadoop.util.HadoopInputFile -import org.apache.parquet.schema.MessageType +import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException} import org.apache.spark.sql._ @@ -46,7 +46,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath import org.apache.spark.sql.execution.ExplainMode -import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelationWithTable, PushableColumnAndNestedColumn} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelationWithTable, PushableColumnAndNestedColumn, VariantMetadata} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.functions._ @@ -2358,6 +2358,229 @@ 
class ParquetV1FilterSuite extends ParquetFilterSuite { } } } + + /** + * Tests for shredded variant filter pushdown in ParquetFilters. + * These tests construct ParquetFilters and directly exercise + * resolveShreddedPrimitive / variantStructEntries. + */ + private def variantParquetFilters( + parquetSchemaStr: String, + variantExtractionSchema: Option[StructType], + caseSensitive: Boolean = true): ParquetFilters = { + val parquetSchema = MessageTypeParser.parseMessageType(parquetSchemaStr) + new ParquetFilters( + parquetSchema, + conf.parquetFilterPushDownDate, + conf.parquetFilterPushDownTimestamp, + conf.parquetFilterPushDownDecimal, + conf.parquetFilterPushDownStringPredicate, + conf.parquetFilterPushDownInFilterThreshold, + caseSensitive, + RebaseSpec(LegacyBehaviorPolicy.CORRECTED), + variantExtractionSchema = variantExtractionSchema) + } + + /** Construct a variant struct field with VariantMetadata. */ + private def variantStructField(name: String, dt: DataType, path: String): StructField = + StructField(name, dt, metadata = VariantMetadata(path, failOnError = true, "UTC").toMetadata) + + /** + * Build the variantExtractionSchema that PushVariantIntoScan produces for a top-level + * variant column `colName` with one extracted field at `variantPath` of type `dt`. + * The struct child is named by its ordinal "0". 
+ */ + private def variantExtractionSchema( + colName: String, + variantPath: String, + dt: DataType): StructType = { + StructType(Seq( + StructField(colName, StructType(Seq(variantStructField("0", dt, variantPath)))))) + } + + test("variant shredded filter: top-level bigint field resolves to physical typed_value path") { + val parquetSchema = + """message spark_schema { + | optional group v { + | optional binary value; + | optional group typed_value { + | optional group a { + | optional binary value; + | optional int64 typed_value; + | } + | } + | } + |}""".stripMargin + val pushDownSchema = variantExtractionSchema("v", "$.a", LongType) + val pf = variantParquetFilters(parquetSchema, Some(pushDownSchema)) + // "v.`0`" should resolve to physical column "v.typed_value.a.typed_value" + val filter = pf.createFilter(sources.GreaterThan("v.`0`", 1000L)) + assert(filter.isDefined, + "Expected filter predicate to be created via shredded column mapping v.`0`") + assert(filter.get.toString.contains("v.typed_value.a.typed_value"), + s"Expected physical path v.typed_value.a.typed_value in filter, got: ${filter.get}") + } + + test("variant shredded filter: without variantExtractionSchema logical path is unknown") { + val parquetSchema = + """message spark_schema { + | optional group v { + | optional binary value; + | optional group typed_value { + | optional group a { + | optional binary value; + | optional int64 typed_value; + | } + | } + | } + |}""".stripMargin + val pf = variantParquetFilters(parquetSchema, variantExtractionSchema = None) + // Without variantExtractionSchema the mapping "v.`0`" is unknown -- None + assert(pf.createFilter(sources.GreaterThan("v.`0`", 1000L)).isEmpty) + } + + test("variant shredded filter: top-level string field resolves correctly") { + val parquetSchema = + """message spark_schema { + | optional group v { + | optional binary value; + | optional group typed_value { + | optional group b { + | optional binary value; + | optional binary 
typed_value (STRING); + | } + | } + | } + |}""".stripMargin + val pushDownSchema = variantExtractionSchema("v", "$.b", StringType) + val pf = variantParquetFilters(parquetSchema, Some(pushDownSchema)) + // GreaterThan with a string value should produce a predicate on the physical column + val filter = pf.createFilter(sources.GreaterThan("v.`0`", "str500")) + assert(filter.isDefined, + "Expected GreaterThan on shredded string column to produce a predicate") + } + + test("variant shredded filter: nested variant (s.v) uses full physical prefix") { + // Bug fixed by this PR: without physPrefix the path was "v.typed_value.a.typed_value" + // instead of the correct "s.v.typed_value.a.typed_value". + val parquetSchema = + """message spark_schema { + | optional group s { + | optional group v { + | optional binary value; + | optional group typed_value { + | optional group a { + | optional binary value; + | optional int64 typed_value; + | } + | } + | } + | } + |}""".stripMargin + // variantExtractionSchema: s: struct> + val pushDownSchema = StructType(Seq( + StructField("s", StructType(Seq( + StructField("v", StructType(Seq( + variantStructField("0", LongType, "$.a"))))))))) + val pf = variantParquetFilters(parquetSchema, Some(pushDownSchema)) + val filter = pf.createFilter(sources.GreaterThan("s.v.`0`", 500L)) + assert(filter.isDefined, + "Expected filter predicate for nested s.v.`0` via shredded column mapping") + assert(filter.get.toString.contains("s.v.typed_value.a.typed_value"), + s"Expected full physical path s.v.typed_value.a.typed_value, got: ${filter.get}") + } + + test("variant shredded filter: multi-level path $.a.b resolves correctly") { + val parquetSchema = + """message spark_schema { + | optional group v { + | optional binary value; + | optional group typed_value { + | optional group a { + | optional binary value; + | optional group typed_value { + | optional group b { + | optional binary value; + | optional int64 typed_value; + | } + | } + | } + | } + | } + 
|}""".stripMargin + val pushDownSchema = variantExtractionSchema("v", "$.a.b", LongType) + val pf = variantParquetFilters(parquetSchema, Some(pushDownSchema)) + val filter = pf.createFilter(sources.LessThan("v.`0`", 100L)) + assert(filter.isDefined, + "Expected filter predicate for multi-level path $.a.b via shredded column mapping") + assert(filter.get.toString.contains("v.typed_value.a.typed_value.b.typed_value"), + s"Expected multi-level physical path in filter, got: ${filter.get}") + } + + test("variant shredded filter: field absent from physical schema falls back gracefully") { + val parquetSchema = + """message spark_schema { + | optional group v { + | optional binary value; + | optional group typed_value { + | optional group a { + | optional binary value; + | optional int64 typed_value; + | } + | } + | } + |}""".stripMargin + val pushDownSchema = variantExtractionSchema("v", "$.missing", LongType) + val pf = variantParquetFilters(parquetSchema, Some(pushDownSchema)) + // resolveShredded returns None -- no FilterPredicate created (falls back to record-level) + assert(pf.createFilter(sources.GreaterThan("v.`0`", 1000L)).isEmpty) + } + + test("variant shredded filter: array index path is skipped (falls back to record-level)") { + val parquetSchema = + """message spark_schema { + | optional group v { + | optional binary value; + | optional group typed_value (LIST) { + | repeated group list { + | optional group element { + | optional binary value; + | optional int64 typed_value; + | } + | } + | } + | } + |}""".stripMargin + // Path "$.arr[0]" contains an ArrayExtraction segment + val pushDownSchema = variantExtractionSchema("v", "$.arr[0]", LongType) + val pf = variantParquetFilters(parquetSchema, Some(pushDownSchema)) + assert(pf.createFilter(sources.GreaterThan("v.`0`", 1000L)).isEmpty) + } + + test("variant shredded filter: case-insensitive column matching") { + val parquetSchema = + """message spark_schema { + | optional group V { + | optional binary value; + 
| optional group typed_value { + | optional group a { + | optional binary value; + | optional int64 typed_value; + | } + | } + | } + |}""".stripMargin + // Spark schema uses lowercase "v" but Parquet column is "V" + val pushDownSchema = variantExtractionSchema("v", "$.a", LongType) + val pfSensitive = + variantParquetFilters(parquetSchema, Some(pushDownSchema), caseSensitive = true) + assert(pfSensitive.createFilter(sources.GreaterThan("v.`0`", 1000L)).isEmpty, + "Case-sensitive: 'v' should not match Parquet column 'V'") + val pfInsensitive = + variantParquetFilters(parquetSchema, Some(pushDownSchema), caseSensitive = false) + assert(pfInsensitive.createFilter(sources.GreaterThan("v.`0`", 1000L)).isDefined, + "Case-insensitive: 'v' should match Parquet column 'V'") + } } @ExtendedSQLTest diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantShreddingFilterPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantShreddingFilterPushdownSuite.scala new file mode 100644 index 0000000000000..4933d44154d0f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantShreddingFilterPushdownSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.AccumulatorContext + +/** + * Tests for row-group skipping on shredded Variant columns in Parquet (DSv1). + * + * When a Variant column is written with shredding enabled, each extracted field is stored as a + * separate typed Parquet column (e.g. `v.typed_value.a.typed_value` for `$.a`) with standard + * row-group min/max statistics. The PUSH_VARIANT_INTO_SCAN rule rewrites + * `variant_get(v, '$.a', 'bigint') > 100` into a struct field access `v.`0` > 100`. + * ParquetFilters then maps the logical path `v.`0`` to the physical shredded column + * `v.typed_value.a.typed_value` via VariantMetadata, enabling Parquet FilterPredicate to + * evaluate row-group min/max statistics and skip row groups that cannot contain matching rows. 
+ */
+class VariantShreddingFilterPushdownSuite extends QueryTest with ParquetTest
+  with SharedSparkSession {
+
+  private val shreddingWriteConf = Seq(
+    SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> "true",
+    SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "true",
+    SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key -> "false",
+    SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "true"
+  )
+
+  /**
+   * Writes a shredded Variant Parquet file coalesced to 1 partition with a tiny block size so
+   * the writer produces at least 2 non-overlapping row groups:
+   *   - Row group 0: ids 0..999, a in [0, 999], b in ["str0".."str999"]
+   *   - Row group 1: ids 1000..1999, a in [1000, 1999], b in ["str1000".."str1999"]
+   */
+  private def writeTwoRowGroups(dir: java.io.File): Unit = {
+    withSQLConf(shreddingWriteConf: _*) {
+      spark.sql(
+        """SELECT parse_json('{"a":' || id || ', "b":"str' || id || '"}') AS v
+          |FROM range(0, 2000, 1, 1)""".stripMargin
+      ).coalesce(1)
+        .write
+        .option("parquet.block.size", 512)
+        .mode("overwrite")
+        .parquet(dir.getAbsolutePath)
+    }
+  }
+
+  /**
+   * Counts how many Parquet row groups are actually read by running the DataFrame with an
+   * accumulator-based counter. Uses the same technique from ParquetFilterSuite.
+   */
+  private def countRowGroupsRead(df: org.apache.spark.sql.DataFrame): Int = {
+    val accu = new NumRowGroupsAcc
+    sparkContext.register(accu)
+    try {
+      df.foreachPartition((it: Iterator[Row]) => it.foreach(_ => accu.add(0)))
+      accu.value
+    } finally {
+      AccumulatorContext.remove(accu.id)
+    }
+  }
+
+  test("variant_get integer predicate skips row groups and returns correct rows") {
+    // Set PARQUET_VECTORIZED_READER_ENABLED to true to count reads of row groups.
+    // ParquetFilters maps "v.`0`" to "v.typed_value.a.typed_value" via VariantMetadata.
+ // Row group 0: a in [0, 999] -- dropped by "a > 999" (max=999 is not > 999) + // Row group 1: a in [1000, 1999] -- kept + withTempDir { dir => + writeTwoRowGroups(dir) + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", + SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "true" + ) { + val dfFiltered = spark.read.parquet(dir.getAbsolutePath) + .selectExpr("variant_get(v, '$.a', 'bigint') AS a") + .where("a > 999") + val dfAll = spark.read.parquet(dir.getAbsolutePath) + .selectExpr("variant_get(v, '$.a', 'bigint') AS a") + + assert(countRowGroupsRead(dfFiltered) < countRowGroupsRead(dfAll), + "Expected at least one row group to be skipped by the shredded column statistics") + checkAnswer(dfFiltered.orderBy("a"), (1000L to 1999L).map(Row(_))) + } + } + } + + test("variant_get string predicate skips row groups and returns correct rows") { + // Set PARQUET_VECTORIZED_READER_ENABLED to true to count read of row groups. + // Row group 0: b in ["str0".."str999"]; row group 1: b in ["str1000".."str1999"]. + // "b = 'str1500'" skips row group 0 (str1500 > str999 lexicographically). + withTempDir { dir => + writeTwoRowGroups(dir) + withSQLConf( + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", + SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "true" + ) { + val dfFiltered = spark.read.parquet(dir.getAbsolutePath) + .selectExpr("variant_get(v, '$.b', 'string') AS b") + .where("b = 'str1500'") + val dfAll = spark.read.parquet(dir.getAbsolutePath) + .selectExpr("variant_get(v, '$.b', 'string') AS b") + + assert(countRowGroupsRead(dfFiltered) < countRowGroupsRead(dfAll), + "Expected at least one row group to be skipped by the shredded string column statistics") + checkAnswer(dfFiltered, Seq(Row("str1500"))) + } + } + } +}