From 076746f3589dbf87dc569d5614309c607fa725fe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 2 Mar 2026 16:24:05 -0800 Subject: [PATCH] [SPARK] Fix integer overflow when computing Arrow batch bytes `ArrowWriter.sizeInBytes()` and `SliceBytesArrowOutputProcessorImpl.getBatchBytes()` both accumulated per-column buffer sizes (each an `Int`) into an `Int` accumulator. When the total exceeds 2 GB, the sum silently wraps negative, causing the byte-limit checks controlled by `spark.sql.execution.arrow.maxBytesPerBatch` and `spark.sql.execution.arrow.maxBytesPerOutputBatch` to behave incorrectly and potentially allow oversized batches through. Fix by changing both accumulators and return types to `Long`. Co-Authored-By: Claude Sonnet 4.6 --- .../org/apache/spark/sql/execution/arrow/ArrowWriter.scala | 4 ++-- .../apache/spark/sql/execution/python/PythonArrowOutput.scala | 4 ++-- .../python/streaming/BaseStreamingArrowWriterSuite.scala | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 8d68e74dbf874..b5269da035f3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -124,9 +124,9 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { count += 1 } - def sizeInBytes(): Int = { + def sizeInBytes(): Long = { var i = 0 - var bytes = 0 + var bytes = 0L while (i < fields.size) { bytes += fields(i).getSizeInBytes() i += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index 38fd081a104f6..bae1a2aa0d5da 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -297,8 +297,8 @@ class SliceBytesArrowOutputProcessorImpl( } } - private def getBatchBytes(root: VectorSchemaRoot): Int = { - var batchBytes = 0 + private def getBatchBytes(root: VectorSchemaRoot): Long = { + var batchBytes = 0L root.getFieldVectors.asScala.foreach { vector => batchBytes += vector.getBufferSize } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala index aa6aca5076243..bbd2420c588f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala @@ -75,7 +75,7 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite { () } - when(arrowWriter.sizeInBytes()).thenAnswer { _ => sizeCounter } + when(arrowWriter.sizeInBytes()).thenAnswer { _ => sizeCounter.toLong } // Set arrowMaxBytesPerBatch to 1 transformWithStateInPySparkWriter = new BaseStreamingArrowWriter(