
Commit 11a8cef

[SPARK-53342][SQL] Fix Arrow converter to handle multiple record batches in single IPC stream
### What changes were proposed in this pull request?
This PR adds a new method in ArrowConverters that allows properly decoding an Arrow IPC stream, which can contain multiple record batches. All of the other methods can only deal with message streams that contain exactly one record batch.

### Why are the changes needed?
Previously, when an Arrow IPC stream contained multiple record batches, only the first batch would be processed and the remaining batches would be ignored. This resulted in data loss and incorrect results when working with Arrow data that was serialized as a single stream with multiple batches.

### Does this PR introduce _any_ user-facing change?
Yes. This fixes a data correctness issue where users would lose data when processing Arrow streams with multiple batches. The behavior change is that all batches in a stream are now correctly processed instead of only the first one.

### How was this patch tested?
Added comprehensive test cases.

### Was this patch authored or co-authored using generative AI tooling?
Tests Generated-by: Claude Code

🤖 Generated with [Claude Code](https://claude.ai/code)

Closes apache#52090 from grundprinzip/SPARK-53342.

Authored-by: Martin Grund <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 83d5ff1 commit 11a8cef
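
For readers unfamiliar with the Arrow IPC stream format referenced above, the sketch below (not part of this commit) builds such a payload with the Arrow Java API: one schema message followed by two record batches in a single byte stream. The column name `id`, the batch count, and the row counts are arbitrary choices for illustration; with the old single-batch path, only the first batch of a payload like this was decoded.

import java.io.ByteArrayOutputStream
import java.util.Collections

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{IntVector, VectorSchemaRoot}
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

object MultiBatchIPCStreamExample {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator(Long.MaxValue)
    val idField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null)
    val schema = new Schema(Collections.singletonList(idField))
    val root = VectorSchemaRoot.create(schema, allocator)
    val out = new ByteArrayOutputStream()
    val writer = new ArrowStreamWriter(root, null, out)

    writer.start()                              // writes the schema message once
    for (batch <- 0 until 2) {                  // then N record batches in the same stream
      val vector = root.getVector("id").asInstanceOf[IntVector]
      vector.allocateNew(3)
      (0 until 3).foreach(i => vector.setSafe(i, batch * 3 + i))
      root.setRowCount(3)
      writer.writeBatch()
    }
    writer.end()

    // A single IPC payload containing one schema and two record batches; with the old
    // fromBatchWithSchemaIterator path only the first batch's three rows survived.
    val ipcBytes: Array[Byte] = out.toByteArray
    println(s"IPC stream: ${ipcBytes.length} bytes, 2 batches, 6 rows total")

    writer.close()
    root.close()
    allocator.close()
  }
}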

3 files changed, +538 -3 lines changed


sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 2 additions & 3 deletions
@@ -1308,9 +1308,8 @@ class SparkConnectPlanner(
     }
 
     if (rel.hasData) {
-      val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
-        Iterator(rel.getData.toByteArray),
-        TaskContext.get())
+      val (rows, structType) =
+        ArrowConverters.fromIPCStream(rel.getData.toByteArray, TaskContext.get())
       if (structType == null) {
         throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.")
       }

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala

Lines changed: 120 additions & 0 deletions
@@ -264,6 +264,109 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  /**
+   * This is a class that converts input data in the form of a byte array to InternalRow
+   * instances implementing the Iterator interface.
+   *
+   * The input data must be a valid Arrow IPC stream; this means that the first message is
+   * always the schema, followed by N record batches.
+   *
+   * @param input Input data
+   * @param context Task context for Spark
+   */
+  private[sql] class InternalRowIteratorFromIPCStream(
+      input: Array[Byte],
+      context: TaskContext) extends Iterator[InternalRow] {
+
+    // Keep all the resources we have opened, in order; they are closed
+    // in reverse order at the end.
+    private val resources = new ArrayBuffer[AutoCloseable]()
+
+    // Create an allocator used for all Arrow-related memory.
+    protected val allocator: BufferAllocator = ArrowUtils.rootAllocator.newChildAllocator(
+      s"to${this.getClass.getSimpleName}",
+      0,
+      Long.MaxValue)
+    resources.append(allocator)
+
+    private val reader = try {
+      new ArrowStreamReader(new ByteArrayInputStream(input), allocator)
+    } catch {
+      case e: Exception =>
+        closeAll(resources.toSeq.reverse: _*)
+        throw new IllegalArgumentException(
+          s"Failed to create ArrowStreamReader: ${e.getMessage}", e)
+    }
+    resources.append(reader)
+
+    private val root: VectorSchemaRoot = try {
+      reader.getVectorSchemaRoot
+    } catch {
+      case e: Exception =>
+        closeAll(resources.toSeq.reverse: _*)
+        throw new IllegalArgumentException(
+          s"Failed to read schema from IPC stream: ${e.getMessage}", e)
+    }
+    resources.append(root)
+
+    val schema: StructType = try {
+      ArrowUtils.fromArrowSchema(root.getSchema)
+    } catch {
+      case e: Exception =>
+        closeAll(resources.toSeq.reverse: _*)
+        throw new IllegalArgumentException(s"Failed to convert Arrow schema: ${e.getMessage}", e)
+    }
+
+    // TODO: wrap in exception
+    private var rowIterator: Iterator[InternalRow] = vectorSchemaRootToIter(root)
+
+    // Metrics to track batch processing.
+    private var _batchesLoaded: Int = 0
+    private var _totalRowsProcessed: Long = 0L
+
+    if (context != null) {
+      context.addTaskCompletionListener[Unit] { _ =>
+        closeAll(resources.toSeq.reverse: _*)
+      }
+    }
+
+    // Public accessors for metrics.
+    def batchesLoaded: Int = _batchesLoaded
+    def totalRowsProcessed: Long = _totalRowsProcessed
+
+    // Loads the next batch from the Arrow reader and returns whether a batch could be loaded.
+    private def loadNextBatch(): Boolean = {
+      if (reader.loadNextBatch()) {
+        rowIterator = vectorSchemaRootToIter(root)
+        _batchesLoaded += 1
+        true
+      } else {
+        false
+      }
+    }
+
+    override def hasNext: Boolean = {
+      if (rowIterator.hasNext) {
+        true
+      } else {
+        if (!loadNextBatch()) {
+          false
+        } else {
+          hasNext
+        }
+      }
+    }
+
+    override def next(): InternalRow = {
+      if (!hasNext) {
+        throw new NoSuchElementException("No more elements in iterator")
+      }
+      _totalRowsProcessed += 1
+      rowIterator.next()
+    }
+  }
+
   /**
    * An InternalRow iterator which parse data from serialized ArrowRecordBatches, subclass should
    * implement [[nextBatch]] to parse data from binary records.
@@ -382,6 +485,23 @@ private[sql] object ArrowConverters extends Logging {
     (iterator, iterator.schema)
   }
 
+  /**
+   * Creates an iterator from a byte array to deserialize an Arrow IPC stream with exactly
+   * one schema and a varying number of record batches. Returns an iterator over the
+   * created InternalRow.
+   */
+  private[sql] def fromIPCStream(input: Array[Byte], context: TaskContext):
+      (Iterator[InternalRow], StructType) = {
+    fromIPCStreamWithIterator(input, context)
+  }
+
+  // Overloaded method for tests to access the iterator with metrics.
+  private[sql] def fromIPCStreamWithIterator(input: Array[Byte], context: TaskContext):
+      (InternalRowIteratorFromIPCStream, StructType) = {
+    val iterator = new InternalRowIteratorFromIPCStream(input, context)
+    (iterator, iterator.schema)
+  }
+
   /**
    * Convert an arrow batch container into an iterator of InternalRow.
   */
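
As a rough usage sketch (not the actual test suite added by this commit, which is not shown in this view), the new test-facing overload could be exercised as follows. It assumes `ipcBytes` holds a two-batch, six-row stream like the one in the earlier sketch, and, because these APIs are `private[sql]`, it would have to live under the `org.apache.spark.sql` package.

package org.apache.spark.sql.execution.arrow

import org.apache.spark.TaskContext

object FromIPCStreamSketch {
  // `ipcBytes` is assumed to hold one schema message followed by two 3-row record batches.
  def readAll(ipcBytes: Array[Byte]): Unit = {
    // Test-facing overload that also exposes the iterator's batch/row metrics.
    val (iter, schema) = ArrowConverters.fromIPCStreamWithIterator(ipcBytes, TaskContext.get())
    val rows = iter.toArray                 // drains every batch, not just the first one

    assert(schema != null)
    assert(rows.length == 6)                // 2 batches x 3 rows
    assert(iter.batchesLoaded == 2)         // both batches were loaded from the stream
    assert(iter.totalRowsProcessed == 6L)
  }
}

Note that when TaskContext.get() returns null (for example on the driver), the iterator skips registering the task completion listener, so a caller in that position is responsible for releasing the child allocator and reader.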
