[jvm-packages] ExternalMemory: Overlap the caching time #11474

Open
wants to merge 3 commits into master
Changes from all commits
@@ -18,10 +18,15 @@ package ml.dmlc.xgboost4j.scala.spark

import java.io.File
import java.nio.file.{Files, Paths}
import java.util.concurrent.Executors

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.DurationInt

import ai.rapids.cudf._
import org.apache.commons.logging.LogFactory

import ml.dmlc.xgboost4j.java.{ColumnBatch, CudfColumnBatch}
import ml.dmlc.xgboost4j.scala.spark.Utils.withResource
@@ -61,14 +66,23 @@ private[spark] trait ExternalMemory[T] extends Iterator[Table] with AutoCloseable {
}

// The data will be cached into disk.
private[spark] class DiskExternalMemoryIterator(val path: String) extends ExternalMemory[String] {
private[spark] class DiskExternalMemoryIterator(val parent: String) extends ExternalMemory[String] {

private val logger = LogFactory.getLog("XGBoostSparkGpuPlugin")

private lazy val root = {
val tmp = path + "/xgboost"
val tmp = parent + "/xgboost"
createDirectory(tmp)
tmp
}

logger.info(s"DiskExternalMemoryIterator createDirectory $root")

// Maps each cache-file path to the Future that is writing that table to disk
private val taskFutures: mutable.HashMap[String, Future[Boolean]] = mutable.HashMap.empty
private val executor = Executors.newFixedThreadPool(2)
implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(executor)

private var counter = 0

private def createDirectory(dirPath: String): Unit = {
@@ -78,20 +92,45 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends ExternalMemory[String] {
}
}

/**
* Cache the table to disk. The write runs in a separate thread so that caching
* overlaps with the processing of the next batch.
*
* @param table the table to be cached
* @param path where to cache the table
*/
private def cacheTableThread(table: Table, path: String): Future[Boolean] = {
Future {
withResource(table) { _ =>
try {
val names = (1 to table.getNumberOfColumns).map(_.toString)
val options = ArrowIPCWriterOptions.builder().withColumnNames(names: _*).build()
withResource(Table.writeArrowIPCChunked(options, new File(path))) { writer =>
writer.write(table)
}
true
} catch {
case e: Throwable =>
// Rethrow so the failure surfaces when Await.result is called in checkAndWaitCachingDone
throw e
}
}
}
}
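
For context, a minimal standalone sketch of the overlap this change introduces: the write is submitted to a small thread pool right away and is only awaited when the cached data is actually needed. The object name, the sleep, and the printed message are illustrative only, not part of the PR.

```scala
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.DurationInt

object OverlapSketch {
  def main(args: Array[String]): Unit = {
    val executor = Executors.newFixedThreadPool(2)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(executor)

    // Kick off the (slow) write without blocking the caller, as cacheTableThread does.
    val pending: Future[Boolean] = Future { Thread.sleep(100); true } // stands in for the Arrow IPC write

    // ... the caller keeps converting the next batches while the write runs ...

    // Block only when the cached data is needed, as checkAndWaitCachingDone does.
    val success = Await.result(pending, 6.seconds)
    println(s"cached: $success")
    executor.shutdown()
  }
}
```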

/**
* Convert the table to the file path under which it will be cached.
* The actual write happens asynchronously in cacheTableThread.
*
* @param table to be converted
* @return the path of the cache file
*/
override def convertTable(table: Table): String = {
val names = (1 to table.getNumberOfColumns).map(_.toString)
val options = ArrowIPCWriterOptions.builder().withColumnNames(names: _*).build()
val path = root + "/table_" + counter + "_" + System.nanoTime();
val path = root + "/table_" + counter + "_" + System.nanoTime()
counter += 1
withResource(Table.writeArrowIPCChunked(options, new File(path))) { writer =>
writer.write(table)
}

// Increase the reference count of the columns so they are not recycled before the async write finishes
val newTable = new Table((0 until table.getNumberOfColumns).map(table.getColumn): _*)
val future = cacheTableThread(newTable, path)
taskFutures += (path -> future)
path
}

@@ -106,18 +145,32 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends ExternalMemory[String] {
}
}

private def checkAndWaitCachingDone(path: String): Unit = {
val futureOpt = taskFutures.get(path)
if (futureOpt.isEmpty) {
throw new RuntimeException(s"Failed to find the caching process for $path")
}
// Wait up to 6 seconds for the caching task to finish.
// TODO: make the timeout configurable
// Await.result throws a TimeoutException if the task does not complete in time.
val success = Await.result(futureOpt.get, 6.seconds)
Review comment (Member):
6 seconds is not a lot when the data is large and the system is busy. Why do we need a timeout here? Can writing to disk hang?

if (!success) { // Failed to cache
throw new RuntimeException(s"Failed to cache table to $path")
}
}
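
One possible way to address the TODO above and the reviewer's question about the 6-second limit is to make the wait timeout configurable. The sketch below is only an illustration: the property name xgboost.spark.cache.wait.timeout.seconds and the convention that a non-positive value means "wait indefinitely" are assumptions, not part of this PR.

```scala
import scala.concurrent.duration.{Duration, SECONDS}
import scala.util.Try

// Hypothetical helper: resolve the caching wait timeout from a JVM system property.
// A non-positive value is interpreted as "wait indefinitely".
object CacheWaitTimeout {
  def resolve(default: Long = 6L): Duration = {
    val secs = sys.props.get("xgboost.spark.cache.wait.timeout.seconds")
      .flatMap(s => Try(s.toLong).toOption)
      .getOrElse(default)
    if (secs <= 0) Duration.Inf else Duration(secs, SECONDS)
  }
}

// Usage in checkAndWaitCachingDone (sketch):
//   val success = Await.result(futureOpt.get, CacheWaitTimeout.resolve())
```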

/**
* Load the table cached at the given path from disk.
*
* @param name to be loaded
* @param path to be loaded
* @return Table
*/
override def loadTable(name: String): Table = {
val file = new File(name)
if (!file.exists()) {
throw new RuntimeException(s"The cache file ${name} doesn't exist" )
}
override def loadTable(path: String): Table = {
val file = new File(path)

try {
checkAndWaitCachingDone(path)

withResource(Table.readArrowIPCChunked(file)) { reader =>
val tables = ArrayBuffer.empty[Table]
closeOnExcept(tables) { tables =>
@@ -147,6 +200,7 @@ private[spark] class DiskExternalMemoryIterator(val path: String) extends ExternalMemory[String] {
}

override def close(): Unit = {
executor.shutdown()
buffers.foreach { path =>
val file = new File(path)
if (file.exists()) {
@@ -169,7 +223,7 @@ private[spark] object ExternalMemory {
*
* The first round iteration gets the input batch that will be
* 1. cached in the external memory
* 2. fed in QuantilDmatrix
* 2. fed in QuantileDMatrix
* The second round iteration returns the cached batch retrieved from external memory.
*
* @param input the spark input iterator
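
A simplified, illustrative sketch of this two-round pattern follows. TwoRoundIterator, convert, and load are made-up names used only for illustration; the actual iterator works on cuDF Tables and delegates the caching to the ExternalMemory implementation shown above.

```scala
import scala.collection.mutable.ArrayBuffer

// Illustrative two-round iterator: the first pass yields each upstream batch while
// caching it (convert); the second pass reloads the cached form (load) and yields it.
class TwoRoundIterator[B, H](input: Iterator[B], convert: B => H, load: H => B)
    extends Iterator[B] {
  private val handles = ArrayBuffer.empty[H]
  private var replay: Iterator[H] = _

  private def inFirstRound: Boolean = replay == null

  override def hasNext: Boolean = {
    if (inFirstRound && !input.hasNext) {
      replay = handles.iterator // switch to the second round once the input is exhausted
    }
    if (inFirstRound) input.hasNext else replay.hasNext
  }

  override def next(): B = {
    if (inFirstRound) {
      val batch = input.next()
      handles += convert(batch) // e.g. write the batch to disk and keep the file path
      batch                     // first round feeds the original batch (e.g. to QuantileDMatrix)
    } else {
      load(replay.next())       // second round replays the batch from external memory
    }
  }
}
```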