Commit ac1f23b
1 parent 967f2b6
File tree: 3 files changed, +530 −2 lines changed

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 2 additions & 2 deletions
@@ -1441,8 +1441,8 @@ class SparkConnectPlanner(
     }
 
     if (rel.hasData) {
-      val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
-        Iterator(rel.getData.toByteArray),
+      val (rows, structType) = ArrowConverters.fromIPCStream(
+        rel.getData.toByteArray,
         TaskContext.get())
       if (structType == null) {
         throw InvalidInputErrors.inputDataForLocalRelationNoSchema()
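
With this change the bytes in rel.getData are interpreted as a complete Arrow IPC stream (a schema message followed by any number of record batches) rather than a single batch-with-schema payload. Below is a minimal producer-side sketch using the Arrow Java IPC writer; the writeIPCStream helper is a hypothetical illustration, not part of this commit, and building/populating the VectorSchemaRoot is omitted:

import java.io.ByteArrayOutputStream
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter

// Hypothetical helper: serialize whatever is currently loaded in `root` as a
// complete Arrow IPC stream (schema message + one record batch + end-of-stream).
def writeIPCStream(root: VectorSchemaRoot): Array[Byte] = {
  val out = new ByteArrayOutputStream()
  val writer = new ArrowStreamWriter(root, null, out) // no dictionary provider
  try {
    writer.start()      // writes the schema message
    writer.writeBatch() // writes the record batch currently held by `root`
    writer.end()        // writes the end-of-stream marker
  } finally {
    writer.close()
  }
  out.toByteArray
}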

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala

Lines changed: 112 additions & 0 deletions
@@ -264,6 +264,101 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
+  /**
+   * This is a class that converts input data in the form of a Byte array to InternalRow instances
+   * implementing the Iterator interface.
+   *
+   * The input data must be a valid Arrow IPC stream, this means that the first message is always
+   * the schema followed by N record batches.
+   *
+   * @param input Input Data
+   * @param context Task Context for Spark
+   */
+  private[sql] class InternalRowIteratorFromIPCStream(
+      input: Array[Byte],
+      context: TaskContext) extends Iterator[InternalRow] {
+
+    // Keep all the resources we have opened in order, should be closed
+    // in reverse order finally.
+    private val resources = new ArrayBuffer[AutoCloseable]()
+
+    // Create an allocator used for all Arrow related memory.
+    protected val allocator: BufferAllocator = ArrowUtils.rootAllocator.newChildAllocator(
+      s"to${this.getClass.getSimpleName}",
+      0,
+      Long.MaxValue)
+    resources.append(allocator)
+
+    private val reader = try {
+      new ArrowStreamReader(new ByteArrayInputStream(input), allocator)
+    } catch {
+      case e: Exception =>
+        closeAll(resources.toSeq.reverse: _*)
+        throw new IllegalArgumentException(
+          s"Failed to create ArrowStreamReader: ${e.getMessage}", e)
+    }
+    resources.append(reader)
+
+    private val root: VectorSchemaRoot = try {
+      reader.getVectorSchemaRoot
+    } catch {
+      case e: Exception =>
+        closeAll(resources.toSeq.reverse: _*)
+        throw new IllegalArgumentException(
+          s"Failed to read schema from IPC stream: ${e.getMessage}", e)
+    }
+    resources.append(root)
+
+    val schema: StructType = try {
+      ArrowUtils.fromArrowSchema(root.getSchema)
+    } catch {
+      case e: Exception =>
+        closeAll(resources.toSeq.reverse: _*)
+        throw new IllegalArgumentException(s"Failed to convert Arrow schema: ${e.getMessage}", e)
+    }
+
+    // TODO: wrap in exception
+    private var rowIterator: Iterator[InternalRow] = vectorSchemaRootToIter(root)
+
+    // Metrics to track batch processing
+    private var _batchesLoaded: Int = 0
+    private var _totalRowsProcessed: Long = 0L
+
+    if (context != null) {
+      context.addTaskCompletionListener[Unit] { _ =>
+        closeAll(resources.toSeq.reverse: _*)
+      }
+    }
+
+    // Public accessors for metrics
+    def batchesLoaded: Int = _batchesLoaded
+    def totalRowsProcessed: Long = _totalRowsProcessed
+
+    // Loads the next batch from the Arrow reader and returns true or
+    // false if the next batch could be loaded.
+    private def loadNextBatch(): Boolean = {
+      if (reader.loadNextBatch()) {
+        rowIterator = vectorSchemaRootToIter(root)
+        _batchesLoaded += 1
+        true
+      } else {
+        false
+      }
+    }
+
+    override def hasNext: Boolean = {
+      rowIterator.hasNext || loadNextBatch()
+    }
+
+    override def next(): InternalRow = {
+      if (!hasNext) {
+        throw new NoSuchElementException("No more elements in iterator")
+      }
+      _totalRowsProcessed += 1
+      rowIterator.next()
+    }
+  }
+
   /**
    * An InternalRow iterator which parse data from serialized ArrowRecordBatches, subclass should
    * implement [[nextBatch]] to parse data from binary records.
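
Read back, such a payload is consumed by the new iterator. A minimal usage sketch, assuming ipcBytes holds a valid Arrow IPC stream (for example, one produced by the writeIPCStream helper sketched above) and that the code runs inside a Spark task, so the registered completion listener releases the Arrow resources; note the class is private[sql], so this only compiles from within the sql package:

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.StructType

val iter = new ArrowConverters.InternalRowIteratorFromIPCStream(ipcBytes, TaskContext.get())
val schema: StructType = iter.schema   // decoded from the stream's schema message

while (iter.hasNext) {                 // hasNext lazily loads the next record batch
  val row: InternalRow = iter.next()
  // ... consume the row ...
}

// Metrics exposed for tests: batches decoded and rows handed out so far.
val batches: Int = iter.batchesLoaded
val rowCount: Long = iter.totalRowsProcessed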
@@ -382,6 +477,23 @@ private[sql] object ArrowConverters extends Logging {
     (iterator, iterator.schema)
   }
 
+  /**
+   * Creates an iterator from a Byte array to deserialize an Arrow IPC stream with exactly
+   * one schema and a varying number of record batches. Returns an iterator over the
+   * created InternalRow.
+   */
+  private[sql] def fromIPCStream(input: Array[Byte], context: TaskContext):
+      (Iterator[InternalRow], StructType) = {
+    fromIPCStreamWithIterator(input, context)
+  }
+
+  // Overloaded method for tests to access the iterator with metrics
+  private[sql] def fromIPCStreamWithIterator(input: Array[Byte], context: TaskContext):
+      (InternalRowIteratorFromIPCStream, StructType) = {
+    val iterator = new InternalRowIteratorFromIPCStream(input, context)
+    (iterator, iterator.schema)
+  }
+
   /**
    * Convert an arrow batch container into an iterator of InternalRow.
    */
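
The planner change above goes through the tuple-returning entry point; tests can use fromIPCStreamWithIterator to keep the concrete iterator and inspect its metrics. A short illustrative sketch reusing the ipcBytes binding and imports from the previous sketches; the assertions are assumptions, not taken from this commit's test file:

// Planner-style usage: an InternalRow iterator plus the decoded schema.
val (rows, structType) = ArrowConverters.fromIPCStream(ipcBytes, TaskContext.get())

// Test-style usage: keep the concrete iterator to check its metrics.
val (it, _) = ArrowConverters.fromIPCStreamWithIterator(ipcBytes, TaskContext.get())
val consumed = it.size                    // drains the iterator, loading batches lazily
assert(it.totalRowsProcessed == consumed)
assert(it.batchesLoaded >= 1)             // assuming the stream carries at least one batch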
