
Commit c170ec9

[SPARK-53342][SQL] Fix Arrow converter to handle multiple record batches in single IPC stream
This PR adds a new method in ArrowConverters that properly decodes an Arrow IPC stream, which can contain multiple record batches. All of the other methods can only deal with message streams that contain exactly one record batch.

Previously, when an Arrow IPC stream contained multiple record batches, only the first batch would be processed and the remaining batches would be ignored. This resulted in data loss and incorrect results when working with Arrow data that was serialized as a single stream with multiple batches.

Yes, this is a user-facing change: it fixes a data correctness issue where users would lose data when processing Arrow streams with multiple batches. The behavior change is that all batches in a stream are now correctly processed instead of only the first one.

Added comprehensive test cases.

Generated-by: Claude Code
🤖 Generated with [Claude Code](https://claude.ai/code)

Closes apache#52090 from grundprinzip/SPARK-53342.

Authored-by: Martin Grund <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent d3d84e0 commit c170ec9
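To illustrate the stream layout this fix targets, here is a minimal, self-contained sketch (not part of the commit; the object name and values are illustrative) that uses the Arrow Java API from Scala to write a single IPC stream containing one schema message followed by two record batches. A reader that stops after the first batch silently drops the rows of the second, which is exactly the data loss described above.

import java.io.ByteArrayOutputStream
import scala.jdk.CollectionConverters._

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{IntVector, VectorSchemaRoot}
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

object MultiBatchStreamExample {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)
    val schema = new Schema(
      List(new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null)).asJava)
    val root = VectorSchemaRoot.create(schema, allocator)
    val out = new ByteArrayOutputStream()
    val writer = new ArrowStreamWriter(root, null, out)
    writer.start()

    val ids = root.getVector("id").asInstanceOf[IntVector]

    // First record batch: one row.
    ids.allocateNew(1)
    ids.setSafe(0, 1)
    root.setRowCount(1)
    writer.writeBatch()

    // Second record batch, appended to the same stream: two more rows.
    ids.setSafe(0, 2)
    ids.setSafe(1, 3)
    root.setRowCount(2)
    writer.writeBatch()

    writer.end()
    // One schema message + two record batches; only a reader that keeps calling
    // loadNextBatch() until it returns false sees all three rows.
    val ipcStreamBytes: Array[Byte] = out.toByteArray

    writer.close()
    root.close()
    allocator.close()
  }
}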

File tree

3 files changed (+537, -2 lines)


connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 1 addition & 2 deletions
@@ -704,8 +704,7 @@ class SparkConnectPlanner(val session: SparkSession) {
     }

     if (rel.hasData) {
-      val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
-        Iterator(rel.getData.toByteArray),
+      val (rows, structType) = ArrowConverters.fromIPCStream(rel.getData.toByteArray,
         TaskContext.get())
       if (structType == null) {
         throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.")
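A hedged sketch of what this call-site change means, using identifiers from the diff; `decodeLocalRelationData` is a hypothetical helper, and `fromIPCStream` is `private[sql]`, so such code only compiles inside the `org.apache.spark.sql` package tree.

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.StructType

// `bytes` is a complete Arrow IPC stream (schema message + N record batches),
// e.g. rel.getData.toByteArray in the planner above.
def decodeLocalRelationData(bytes: Array[Byte]): (Iterator[InternalRow], StructType) = {
  // fromIPCStream consumes the whole stream, so the returned iterator covers
  // every record batch rather than only the first one, which is all the old
  // fromBatchWithSchemaIterator path decoded.
  ArrowConverters.fromIPCStream(bytes, TaskContext.get())
}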

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala

Lines changed: 120 additions & 0 deletions
@@ -234,6 +234,109 @@ private[sql] object ArrowConverters extends Logging {
     }
   }

+  /**
+   * This is a class that converts input data in the form of a Byte array to InternalRow instances
+   * implementing the Iterator interface.
+   *
+   * The input data must be a valid Arrow IPC stream; this means that the first message is always
+   * the schema followed by N record batches.
+   *
+   * @param input Input data
+   * @param context Task context for Spark
+   */
+  private[sql] class InternalRowIteratorFromIPCStream(
+      input: Array[Byte],
+      context: TaskContext) extends Iterator[InternalRow] {
+
+    // Keep all the resources we have opened in order; they should be closed
+    // in reverse order at the end.
+    private val resources = new ArrayBuffer[AutoCloseable]()
+
+    // Create an allocator used for all Arrow related memory.
+    protected val allocator: BufferAllocator = ArrowUtils.rootAllocator.newChildAllocator(
+      s"to${this.getClass.getSimpleName}",
+      0,
+      Long.MaxValue)
+    resources.append(allocator)
+
+    private val reader = try {
+      new ArrowStreamReader(new ByteArrayInputStream(input), allocator)
+    } catch {
+      case e: Exception =>
+        closeAll(resources.toSeq.reverse: _*)
+        throw new IllegalArgumentException(
+          s"Failed to create ArrowStreamReader: ${e.getMessage}", e)
+    }
+    resources.append(reader)
+
+    private val root: VectorSchemaRoot = try {
+      reader.getVectorSchemaRoot
+    } catch {
+      case e: Exception =>
+        closeAll(resources.toSeq.reverse: _*)
+        throw new IllegalArgumentException(
+          s"Failed to read schema from IPC stream: ${e.getMessage}", e)
+    }
+    resources.append(root)
+
+    val schema: StructType = try {
+      ArrowUtils.fromArrowSchema(root.getSchema)
+    } catch {
+      case e: Exception =>
+        closeAll(resources.toSeq.reverse: _*)
+        throw new IllegalArgumentException(s"Failed to convert Arrow schema: ${e.getMessage}", e)
+    }
+
+    // TODO: wrap in exception
+    private var rowIterator: Iterator[InternalRow] = vectorSchemaRootToIter(root)
+
+    // Metrics to track batch processing
+    private var _batchesLoaded: Int = 0
+    private var _totalRowsProcessed: Long = 0L
+
+    if (context != null) {
+      context.addTaskCompletionListener[Unit] { _ =>
+        closeAll(resources.toSeq.reverse: _*)
+      }
+    }
+
+    // Public accessors for metrics
+    def batchesLoaded: Int = _batchesLoaded
+    def totalRowsProcessed: Long = _totalRowsProcessed
+
+    // Loads the next batch from the Arrow reader and returns true if it could
+    // be loaded, false otherwise.
+    private def loadNextBatch(): Boolean = {
+      if (reader.loadNextBatch()) {
+        rowIterator = vectorSchemaRootToIter(root)
+        _batchesLoaded += 1
+        true
+      } else {
+        false
+      }
+    }
+
+    override def hasNext: Boolean = {
+      if (rowIterator.hasNext) {
+        true
+      } else {
+        if (!loadNextBatch()) {
+          false
+        } else {
+          hasNext
+        }
+      }
+    }
+
+    override def next(): InternalRow = {
+      if (!hasNext) {
+        throw new NoSuchElementException("No more elements in iterator")
+      }
+      _totalRowsProcessed += 1
+      rowIterator.next()
+    }
+  }
+
   /**
    * An InternalRow iterator which parses data from serialized ArrowRecordBatches; subclasses should
    * implement [[nextBatch]] to parse data from binary records.
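The iteration pattern above (drain the current batch, then lazily pull the next one from the reader) can be illustrated independently of Arrow. The following standalone sketch, with hypothetical names, mirrors the hasNext/loadNextBatch logic.

// Wrap a sequence of "batches" behind a single flat Iterator, loading the next
// batch lazily only when the current one is exhausted. Purely illustrative.
class BatchChainIterator[T](batches: Iterator[Seq[T]]) extends Iterator[T] {
  private var current: Iterator[T] = Iterator.empty

  override def hasNext: Boolean = {
    if (current.hasNext) {
      true
    } else if (batches.hasNext) {
      // Advance to the next batch (skipping empty ones via the recursion).
      current = batches.next().iterator
      hasNext
    } else {
      false
    }
  }

  override def next(): T = {
    if (!hasNext) throw new NoSuchElementException("No more elements in iterator")
    current.next()
  }
}

For example, `new BatchChainIterator(Iterator(Seq(1), Seq(2, 3))).toList` yields `List(1, 2, 3)`, analogous to reading every record batch of an IPC stream instead of only the first.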
@@ -345,6 +448,23 @@ private[sql] object ArrowConverters extends Logging {
     (iterator, iterator.schema)
   }

+  /**
+   * Creates an iterator from a Byte array to deserialize an Arrow IPC stream with exactly
+   * one schema and a varying number of record batches. Returns an iterator over the
+   * created InternalRow instances.
+   */
+  private[sql] def fromIPCStream(input: Array[Byte], context: TaskContext):
+      (Iterator[InternalRow], StructType) = {
+    fromIPCStreamWithIterator(input, context)
+  }
+
+  // Variant used by tests to access the iterator with its metrics.
+  private[sql] def fromIPCStreamWithIterator(input: Array[Byte], context: TaskContext):
+      (InternalRowIteratorFromIPCStream, StructType) = {
+    val iterator = new InternalRowIteratorFromIPCStream(input, context)
+    (iterator, iterator.schema)
+  }
+
   /**
    * Convert an arrow batch container into an iterator of InternalRow.
    */
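A hedged, test-style sketch of how these entry points might be exercised. Both methods are `private[sql]`, so this only compiles from code inside the `org.apache.spark.sql` package; `ipcStreamBytes` is assumed to be the two-batch stream built in the first example, and the assertions reflect that assumption.

import org.apache.spark.TaskContext
import org.apache.spark.sql.execution.arrow.ArrowConverters

// Plain entry point: yields the rows of every batch in the stream.
val (rows, schema) = ArrowConverters.fromIPCStream(ipcStreamBytes, TaskContext.get())
assert(rows.size == 3)  // all three rows, not just the first batch's single row

// Test-facing variant: exposes the iterator so batch/row metrics can be checked.
// A null TaskContext is tolerated; the completion listener is simply not registered.
val (iter, _) = ArrowConverters.fromIPCStreamWithIterator(ipcStreamBytes, null)
val all = iter.toArray
assert(iter.batchesLoaded == 2)                // both record batches were loaded
assert(iter.totalRowsProcessed == all.length)  // row metric matches the rows returned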
