diff --git a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/ExternalMemory.scala b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/ExternalMemory.scala
index 735941e679c9..f6247bfd4f0b 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/ExternalMemory.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/main/scala/ml/dmlc/xgboost4j/scala/spark/ExternalMemory.scala
@@ -18,10 +18,15 @@ package ml.dmlc.xgboost4j.scala.spark
 
 import java.io.File
 import java.nio.file.{Files, Paths}
+import java.util.concurrent.Executors
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.duration.DurationInt
 
 import ai.rapids.cudf._
+import org.apache.commons.logging.LogFactory
 
 import ml.dmlc.xgboost4j.java.{ColumnBatch, CudfColumnBatch}
 import ml.dmlc.xgboost4j.scala.spark.Utils.withResource
@@ -61,14 +66,23 @@ private[spark] trait ExternalMemory[T] extends Iterator[Table] with AutoCloseabl
 }
 
 // The data will be cached into disk.
-private[spark] class DiskExternalMemoryIterator(val path: String) extends ExternalMemory[String] {
+private[spark] class DiskExternalMemoryIterator(val parent: String) extends ExternalMemory[String] {
+
+  private val logger = LogFactory.getLog("XGBoostSparkGpuPlugin")
 
   private lazy val root = {
-    val tmp = path + "/xgboost"
+    val tmp = parent + "/xgboost"
     createDirectory(tmp)
     tmp
   }
 
+  logger.info(s"DiskExternalMemoryIterator created cache directory $root")
+
+  // Maps the cache file path to the Future of its caching task
+  private val taskFutures: mutable.HashMap[String, Future[Boolean]] = mutable.HashMap.empty
+  private val executor = Executors.newFixedThreadPool(2)
+  implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(executor)
+
   private var counter = 0
 
   private def createDirectory(dirPath: String): Unit = {
@@ -78,6 +92,31 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends Extern
     }
   }
 
+  /**
+   * Cache the table to disk; the write runs in a separate thread
+   *
+   * @param table the table to be cached
+   * @param path  the file path where the table is cached
+   */
+  private def cacheTableThread(table: Table, path: String): Future[Boolean] = {
+    Future {
+      withResource(table) { _ =>
+        try {
+          val names = (1 to table.getNumberOfColumns).map(_.toString)
+          val options = ArrowIPCWriterOptions.builder().withColumnNames(names: _*).build()
+          withResource(Table.writeArrowIPCChunked(options, new File(path))) { writer =>
+            writer.write(table)
+          }
+          true
+        } catch {
+          case e: Throwable =>
+            logger.error(s"Failed to cache table to $path", e)
+            false
+        }
+      }
+    }
+  }
+
   /**
    * Convert the table to file path which will be cached
    *
@@ -85,13 +124,13 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends Extern
    * @return the content
    */
   override def convertTable(table: Table): String = {
-    val names = (1 to table.getNumberOfColumns).map(_.toString)
-    val options = ArrowIPCWriterOptions.builder().withColumnNames(names: _*).build()
-    val path = root + "/table_" + counter + "_" + System.nanoTime();
+    val path = root + "/table_" + counter + "_" + System.nanoTime()
     counter += 1
-    withResource(Table.writeArrowIPCChunked(options, new File(path))) { writer =>
-      writer.write(table)
-    }
+
+    // Increase the reference count of the columns so they are not recycled prematurely
+    val newTable = new Table((0 until table.getNumberOfColumns).map(table.getColumn): _*)
+    val future = cacheTableThread(newTable, path)
+    taskFutures += (path -> future)
     path
   }
 
@@ -106,18 +145,32 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends Extern
     }
   }
 
+  private def checkAndWaitCachingDone(path: String): Unit = {
+    val futureOpt = taskFutures.get(path)
+    if (futureOpt.isEmpty) {
+      throw new RuntimeException(s"Failed to find the caching process for $path")
+    }
+    // Wait up to 6 seconds for the caching to finish.
+    // TODO: make the timeout configurable
+    // Await.result throws a TimeoutException if the wait times out
+    val success = Await.result(futureOpt.get, 6.seconds)
+    if (!success) { // Failed to cache
+      throw new RuntimeException(s"Failed to cache table to $path")
+    }
+  }
+
   /**
    * Load the path from disk to the Table
    *
-   * @param name to be loaded
+   * @param path to be loaded
    * @return Table
    */
-  override def loadTable(name: String): Table = {
-    val file = new File(name)
-    if (!file.exists()) {
-      throw new RuntimeException(s"The cache file ${name} doesn't exist" )
-    }
+  override def loadTable(path: String): Table = {
+    val file = new File(path)
+    try {
+      checkAndWaitCachingDone(path)
+
     withResource(Table.readArrowIPCChunked(file)) { reader =>
       val tables = ArrayBuffer.empty[Table]
       closeOnExcept(tables) { tables =>
@@ -147,6 +200,7 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends Extern
   }
 
   override def close(): Unit = {
+    executor.shutdown()
     buffers.foreach { path =>
       val file = new File(path)
       if (file.exists()) {
@@ -169,7 +223,7 @@ private[spark] object ExternalMemory {
    *
    * The first round iteration gets the input batch that will be
    * 1. cached in the external memory
-   * 2. fed in QuantilDmatrix
+   * 2. fed into QuantileDMatrix
    * The second round iteration returns the cached batch got from external memory.
    *
    * @param input the spark input iterator
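For context, the caching handshake this patch introduces can be sketched in isolation: each write is scheduled on a small thread pool, the resulting `Future[Boolean]` is remembered keyed by file path, and the reader blocks with a timeout on that future before reading the file back. The snippet below is a minimal, self-contained illustration of that pattern only; `AsyncCacheSketch`, `pendingWrites`, and the plain-string payload are hypothetical stand-ins and not part of the XGBoost code.

```scala
import java.nio.file.{Files, Paths}
import java.util.concurrent.Executors

import scala.collection.mutable
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.DurationInt

object AsyncCacheSketch {
  // A small fixed pool bounds how many cache writes run concurrently
  private val executor = Executors.newFixedThreadPool(2)
  private implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(executor)

  // One Future per cache file, recorded when the write is scheduled
  private val pendingWrites = mutable.HashMap.empty[String, Future[Boolean]]

  // Analogue of convertTable: start the write asynchronously and return the path at once
  def cache(path: String, payload: String): String = {
    val f = Future {
      Files.write(Paths.get(path), payload.getBytes("UTF-8"))
      true
    }
    pendingWrites += (path -> f)
    path
  }

  // Analogue of loadTable: wait (bounded) for the matching write, then read the file back
  def load(path: String): String = {
    val f = pendingWrites.getOrElse(path,
      throw new RuntimeException(s"No caching task recorded for $path"))
    if (!Await.result(f, 6.seconds)) {
      throw new RuntimeException(s"Failed to cache $path")
    }
    new String(Files.readAllBytes(Paths.get(path)), "UTF-8")
  }

  def main(args: Array[String]): Unit = {
    val path = Files.createTempFile("cache_", ".bin").toString
    cache(path, "batch-0")
    println(load(path)) // prints "batch-0" once the background write has finished
    executor.shutdown()
  }
}
```

The fixed pool of two threads mirrors the patch's `Executors.newFixedThreadPool(2)`, which limits how many Arrow IPC writes can be in flight while the training iterator keeps producing batches.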