@@ -264,6 +264,101 @@ private[sql] object ArrowConverters extends Logging {
    }
  }

+  /**
+   * Converts input data in the form of a byte array into InternalRow instances,
+   * exposed through the Iterator interface.
+   *
+   * The input data must be a valid Arrow IPC stream: the first message is always the
+   * schema, followed by N record batches.
+   *
+   * @param input the serialized Arrow IPC stream
+   * @param context the Spark TaskContext, or null when running outside a task
+   */
+  private[sql] class InternalRowIteratorFromIPCStream(
+      input: Array[Byte],
+      context: TaskContext) extends Iterator[InternalRow] {
+
+    // Track every resource we open, in opening order; they must be
+    // closed in reverse order once we are done.
+    private val resources = new ArrayBuffer[AutoCloseable]()
+
+    // Create an allocator used for all Arrow related memory.
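+    // Child allocator with a zero-byte initial reservation and effectively no upper limit.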
+    protected val allocator: BufferAllocator = ArrowUtils.rootAllocator.newChildAllocator(
+      s"to ${this.getClass.getSimpleName}",
+      0,
+      Long.MaxValue)
+    resources.append(allocator)
+
+    private val reader = try {
+      new ArrowStreamReader(new ByteArrayInputStream(input), allocator)
+    } catch {
+      case e: Exception =>
+        closeAll(resources.toSeq.reverse: _*)
+        throw new IllegalArgumentException(
+          s"Failed to create ArrowStreamReader: ${e.getMessage}", e)
+    }
+    resources.append(reader)
+
+    private val root: VectorSchemaRoot = try {
+      reader.getVectorSchemaRoot
+    } catch {
+      case e: Exception =>
+        closeAll(resources.toSeq.reverse: _*)
+        throw new IllegalArgumentException(
+          s"Failed to read schema from IPC stream: ${e.getMessage}", e)
+    }
+    resources.append(root)
+
+    val schema: StructType = try {
+      ArrowUtils.fromArrowSchema(root.getSchema)
+    } catch {
+      case e: Exception =>
+        closeAll(resources.toSeq.reverse: _*)
+        throw new IllegalArgumentException(s"Failed to convert Arrow schema: ${e.getMessage}", e)
+    }
+
+    // TODO: wrap in exception
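+    // Iterator over the rows of the batch currently loaded into `root`;
+    // replaced each time loadNextBatch() pulls in a new batch.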
+    private var rowIterator: Iterator[InternalRow] = vectorSchemaRootToIter(root)
+
+    // Metrics to track batch processing
+    private var _batchesLoaded: Int = 0
+    private var _totalRowsProcessed: Long = 0L
+
+    if (context != null) {
+      context.addTaskCompletionListener[Unit] { _ =>
+        closeAll(resources.toSeq.reverse: _*)
+      }
+    }
+
+    // Public accessors for metrics
+    def batchesLoaded: Int = _batchesLoaded
+    def totalRowsProcessed: Long = _totalRowsProcessed
+
+    // Loads the next batch from the Arrow reader; returns true if a batch
+    // was loaded and false when the stream is exhausted.
+    private def loadNextBatch(): Boolean = {
+      if (reader.loadNextBatch()) {
+        rowIterator = vectorSchemaRootToIter(root)
+        _batchesLoaded += 1
+        true
+      } else {
+        false
+      }
+    }
+
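+    // Either the current batch still has rows left, or we lazily attempt
+    // to load the next batch from the stream.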
+    override def hasNext: Boolean = {
+      rowIterator.hasNext || loadNextBatch()
+    }
+
+    override def next(): InternalRow = {
+      if (!hasNext) {
+        throw new NoSuchElementException("No more elements in iterator")
+      }
+      _totalRowsProcessed += 1
+      rowIterator.next()
+    }
+  }
+

  /**
   * An InternalRow iterator which parse data from serialized ArrowRecordBatches, subclass should
   * implement [[nextBatch]] to parse data from binary records.
@@ -382,6 +477,23 @@ private[sql] object ArrowConverters extends Logging {
    (iterator, iterator.schema)
  }

+  /**
+   * Creates an iterator that deserializes a byte array containing an Arrow IPC stream
+   * with exactly one schema message followed by a varying number of record batches.
+   * Returns the iterator over the resulting InternalRows together with the schema.
+   */
+  private[sql] def fromIPCStream(input: Array[Byte], context: TaskContext):
+      (Iterator[InternalRow], StructType) = {
+    fromIPCStreamWithIterator(input, context)
+  }
+
+  // Variant for tests that returns the concrete iterator type, so that
+  // its batch and row metrics are accessible.
+  private[sql] def fromIPCStreamWithIterator(input: Array[Byte], context: TaskContext):
+      (InternalRowIteratorFromIPCStream, StructType) = {
+    val iterator = new InternalRowIteratorFromIPCStream(input, context)
+    (iterator, iterator.schema)
+  }
+
  /**
   * Convert an arrow batch container into an iterator of InternalRow.
   */
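
To see the new entry point in action, here is a minimal, hypothetical test sketch (not part of this diff). It writes one record batch to an Arrow IPC stream with the stock Arrow Java writer, then reads it back through `fromIPCStream`. It assumes it runs somewhere the `private[sql]` members are visible (e.g. a suite under `org.apache.spark.sql.execution.arrow`) with the Arrow Java libraries on the classpath:

```scala
import java.io.ByteArrayOutputStream
import java.util.Collections

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{IntVector, VectorSchemaRoot}
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

// Build a one-column IPC stream: a schema message plus a single record batch.
val allocator = new RootAllocator(Long.MaxValue)
val field = new Field("i", FieldType.nullable(new ArrowType.Int(32, true)), null)
val schema = new Schema(Collections.singletonList(field))
val root = VectorSchemaRoot.create(schema, allocator)
val out = new ByteArrayOutputStream()
val writer = new ArrowStreamWriter(root, null, out)

writer.start()
val vector = root.getVector("i").asInstanceOf[IntVector]
vector.allocateNew(3)
(0 until 3).foreach(i => vector.setSafe(i, i))
root.setRowCount(3)
writer.writeBatch()
writer.end()
writer.close()
root.close()
allocator.close()

// Read the bytes back. TaskContext is null because this runs outside a task,
// so no completion listener is registered and the child allocator stays open
// for the remainder of this short-lived sketch.
val (rows, readSchema) = ArrowConverters.fromIPCStream(out.toByteArray, null)
assert(readSchema.fieldNames.sameElements(Array("i")))
// Extract values eagerly: the underlying row object may be reused per batch.
assert(rows.map(_.getInt(0)).toSeq == Seq(0, 1, 2))
```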