diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 8d68e74dbf874..b5269da035f3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -124,9 +124,9 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { count += 1 } - def sizeInBytes(): Int = { + def sizeInBytes(): Long = { var i = 0 - var bytes = 0 + var bytes = 0L while (i < fields.size) { bytes += fields(i).getSizeInBytes() i += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index 38fd081a104f6..bae1a2aa0d5da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -297,8 +297,8 @@ class SliceBytesArrowOutputProcessorImpl( } } - private def getBatchBytes(root: VectorSchemaRoot): Int = { - var batchBytes = 0 + private def getBatchBytes(root: VectorSchemaRoot): Long = { + var batchBytes = 0L root.getFieldVectors.asScala.foreach { vector => batchBytes += vector.getBufferSize } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala index aa6aca5076243..bbd2420c588f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriterSuite.scala @@ -75,7 +75,7 @@ class BaseStreamingArrowWriterSuite extends SparkFunSuite { () } - when(arrowWriter.sizeInBytes()).thenAnswer { _ => sizeCounter } + when(arrowWriter.sizeInBytes()).thenAnswer { _ => sizeCounter.toLong } // Set arrowMaxBytesPerBatch to 1 transformWithStateInPySparkWriter = new BaseStreamingArrowWriter(