ArtifactUtils.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql.util

import java.nio.file.{Path, Paths}

import org.apache.spark.SparkRuntimeException

object ArtifactUtils {
@@ -40,4 +41,17 @@ object ArtifactUtils {
private[sql] def concatenatePaths(basePath: Path, otherPath: String): Path = {
concatenatePaths(basePath, Paths.get(otherPath))
}

/**
* Converts a sequence of exceptions into a single exception by adding all but the first
* exception as suppressed exceptions to the first one.
* @param exceptions a non-empty sequence of exceptions to merge
* @return the first exception, with the remaining exceptions attached as suppressed exceptions
*/
private[sql] def mergeExceptionsWithSuppressed(exceptions: Seq[SparkRuntimeException]): SparkRuntimeException = {
require(exceptions.nonEmpty)
val mainException = exceptions.head
exceptions.drop(1).foreach(mainException.addSuppressed)
mainException
}
}
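For reference, a minimal sketch of how the new helper behaves. This is illustrative only: mergeExceptionsWithSuppressed is private[sql], so it can only be called from code inside the org.apache.spark.sql package, and the error condition and parameters below are the ones used by this PR's tests.

    import org.apache.spark.SparkRuntimeException
    import org.apache.spark.sql.util.ArtifactUtils

    // Two failures collected while adding artifacts (constructor shape as used in the tests below).
    val e1 = new SparkRuntimeException(
      "ARTIFACT_ALREADY_EXISTS", Map("normalizedRemoteRelativePath" -> "jars/a.jar"))
    val e2 = new SparkRuntimeException(
      "ARTIFACT_ALREADY_EXISTS", Map("normalizedRemoteRelativePath" -> "jars/b.jar"))

    // The first exception is the one thrown to the caller; the rest ride along as suppressed.
    val merged = ArtifactUtils.mergeExceptionsWithSuppressed(Seq(e1, e2))
    assert(merged eq e1)
    assert(merged.getSuppressed.length == 1)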
SparkConnectAddArtifactsHandler.scala
@@ -26,6 +26,7 @@ import scala.util.control.NonFatal
import com.google.common.io.CountingOutputStream
import io.grpc.stub.StreamObserver

import org.apache.spark.SparkRuntimeException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
@@ -112,19 +113,32 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
* @return the summaries of all staged artifacts
*/
protected def flushStagedArtifacts(): Seq[ArtifactSummary] = {
val failedArtifactExceptions = mutable.ListBuffer[SparkRuntimeException]()

// Non-lazy transformation when using Buffer.
stagedArtifacts.map { artifact =>
// We do not store artifacts that fail the CRC. The failure is reported in the artifact
// summary and it is up to the client to decide whether to retry sending the artifact.
if (artifact.getCrcStatus.contains(true)) {
if (artifact.path.startsWith(ArtifactManager.forwardToFSPrefix + File.separator)) {
holder.artifactManager.uploadArtifactToFs(artifact.path, artifact.stagedPath)
} else {
addStagedArtifactToArtifactManager(artifact)
val summaries = stagedArtifacts.map { artifact =>
try {
// We do not store artifacts that fail the CRC. The failure is reported in the artifact
// summary and it is up to the client to decide whether to retry sending the artifact.
if (artifact.getCrcStatus.contains(true)) {
if (artifact.path.startsWith(ArtifactManager.forwardToFSPrefix + File.separator)) {
holder.artifactManager.uploadArtifactToFs(artifact.path, artifact.stagedPath)
} else {
addStagedArtifactToArtifactManager(artifact)
}
}
} catch {
case e: SparkRuntimeException if e.getCondition == "ARTIFACT_ALREADY_EXISTS" =>
failedArtifactExceptions += e
}
artifact.summary()
}.toSeq

if (failedArtifactExceptions.nonEmpty) {
throw ArtifactUtils.mergeExceptionsWithSuppressed(failedArtifactExceptions.toSeq)
}

summaries
}

protected def cleanUpStagedArtifacts(): Unit = Utils.deleteRecursively(stagingDir.toFile)
@@ -216,6 +230,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
private val fileOut = Files.newOutputStream(stagedPath)
private val countingOut = new CountingOutputStream(fileOut)
private val checksumOut = new CheckedOutputStream(countingOut, new CRC32)
private val overallChecksum = new CRC32()

private val builder = ArtifactSummary.newBuilder().setName(name)
private var artifactSummary: ArtifactSummary = _
@@ -227,13 +242,17 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr

def getCrcStatus: Option[Boolean] = Option(isCrcSuccess)

def getCrc: Long = overallChecksum.getValue

def write(dataChunk: proto.AddArtifactsRequest.ArtifactChunk): Unit = {
try dataChunk.getData.writeTo(checksumOut)
catch {
case NonFatal(e) =>
close()
throw e
}

overallChecksum.update(dataChunk.getData.toByteArray)
updateCrc(checksumOut.getChecksum.getValue == dataChunk.getCrc)
checksumOut.getChecksum.reset()
}
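As background for the new overallChecksum field, a standalone sketch (plain java.util.zip, assumed data and names, not taken from the PR) of the two-level CRC idea: each chunk is verified against the CRC sent with it, while a second CRC32 accumulates over all chunks so two uploads under the same artifact name can be compared as whole files, which is what getCrc in the handler above exposes.

    import java.util.zip.CRC32

    val chunks: Seq[Array[Byte]] = Seq("part-1".getBytes("UTF-8"), "part-2".getBytes("UTF-8"))

    val overall = new CRC32() // accumulates across every chunk of the artifact
    chunks.foreach { chunk =>
      val perChunk = new CRC32() // fresh per chunk, checked against the client-sent chunk CRC
      perChunk.update(chunk)
      overall.update(chunk)
      println(s"chunk crc = ${perChunk.getValue}")
    }
    println(s"overall crc = ${overall.getValue}") // comparable across uploads of the same name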
AddArtifactsHandlerSuite.scala
@@ -32,6 +32,7 @@ import io.grpc.StatusRuntimeException
import io.grpc.protobuf.StatusProto
import io.grpc.stub.StreamObserver

import org.apache.spark.SparkRuntimeException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.sql.connect.ResourceHelper
@@ -43,6 +44,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
private val CHUNK_SIZE: Int = 32 * 1024

private val sessionId = UUID.randomUUID.toString()
private val sessionKey = SessionKey("c1", sessionId)

class DummyStreamObserver(p: Promise[AddArtifactsResponse])
extends StreamObserver[AddArtifactsResponse] {
@@ -51,17 +53,31 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
override def onCompleted(): Unit = {}
}

class TestAddArtifactsHandler(responseObserver: StreamObserver[AddArtifactsResponse])
class TestAddArtifactsHandler(responseObserver: StreamObserver[AddArtifactsResponse],
throwIfArtifactExists: Boolean = false)
extends SparkConnectAddArtifactsHandler(responseObserver) {

// Stop the staged artifacts from being automatically deleted
override protected def cleanUpStagedArtifacts(): Unit = {}

private val finalArtifacts = mutable.Buffer.empty[String]
private val artifactChecksums: mutable.Map[String, Long] = mutable.Map.empty

// Record the artifacts that are sent out for final processing.
override protected def addStagedArtifactToArtifactManager(artifact: StagedArtifact): Unit = {
// Throw if the artifact already exists with a different checksum.
// This mocks the behavior of ArtifactManager.addArtifact without comparing the entire file.
if (throwIfArtifactExists
&& finalArtifacts.contains(artifact.name)
&& artifact.getCrc != artifactChecksums(artifact.name)) {
throw new SparkRuntimeException(
"ARTIFACT_ALREADY_EXISTS",
Map("normalizedRemoteRelativePath" -> artifact.name)
)
}

finalArtifacts.append(artifact.name)
artifactChecksums += (artifact.name -> artifact.getCrc)
}

def getFinalArtifacts: Seq[String] = finalArtifacts.toSeq
@@ -418,4 +434,80 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
}
}


def addSingleChunkArtifact(
handler: SparkConnectAddArtifactsHandler,
sessionKey: SessionKey,
name: String,
artifactPath: Path): Unit = {
val dataChunks = getDataChunks(artifactPath)
assert(dataChunks.size == 1)
val bytes = dataChunks.head
val context = proto.UserContext
.newBuilder()
.setUserId(sessionKey.userId)
.build()
val fileNameNoExtension = artifactPath.getFileName.toString.split('.').head
val singleChunkArtifact = proto.AddArtifactsRequest.SingleChunkArtifact
.newBuilder()
.setName(name)
.setData(
proto.AddArtifactsRequest.ArtifactChunk
.newBuilder()
.setData(bytes)
.setCrc(getCrcValues(crcPath.resolve(fileNameNoExtension + ".txt")).head)
.build())
.build()

val singleChunkArtifactRequest = AddArtifactsRequest
.newBuilder()
.setSessionId(sessionKey.sessionId)
.setUserContext(context)
.setBatch(
proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build())
.build()

handler.onNext(singleChunkArtifactRequest)
}

test("All artifacts are added, even if some fail") {
val promise = Promise[AddArtifactsResponse]()
val handler = new TestAddArtifactsHandler(new DummyStreamObserver(promise),
throwIfArtifactExists = true)
try {
val name1 = "jars/dummy1.jar"
val name2 = "jars/dummy2.jar"
val name3 = "jars/dummy3.jar"

val artifactPath1 = inputFilePath.resolve("smallClassFile.class")
val artifactPath2 = inputFilePath.resolve("smallJar.jar")

assume(artifactPath1.toFile.exists)
addSingleChunkArtifact(handler, sessionKey, name1, artifactPath1)
addSingleChunkArtifact(handler, sessionKey, name3, artifactPath1)

val e = intercept[StatusRuntimeException] {
addSingleChunkArtifact(handler, sessionKey, name1, artifactPath2)
addSingleChunkArtifact(handler, sessionKey, name2, artifactPath1)
addSingleChunkArtifact(handler, sessionKey, name3, artifactPath2)
handler.onCompleted()
}

// All three artifacts should be added, despite the exception
assert(handler.getFinalArtifacts.contains(name1))
assert(handler.getFinalArtifacts.contains(name2))
assert(handler.getFinalArtifacts.contains(name3))

assert(e.getStatus.getCode == Code.INTERNAL)
val statusProto = StatusProto.fromThrowable(e)
assert(statusProto.getDetailsCount == 1)
val details = statusProto.getDetails(0)
val info = details.unpack(classOf[ErrorInfo])

assert(e.getMessage.contains("ARTIFACT_ALREADY_EXISTS"))
assert(info.getMetadataMap().get("messageParameters").contains(name1))
} finally {
handler.forceCleanUp()
}
}
}
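The assertions above unpack the gRPC error details by hand. A small helper in the same spirit could look like the sketch below (assumed names; it uses only the StatusProto, protobuf Any, and ErrorInfo APIs the test already relies on).

    import scala.jdk.CollectionConverters._

    import com.google.rpc.ErrorInfo
    import io.grpc.StatusRuntimeException
    import io.grpc.protobuf.StatusProto

    // Extract the first ErrorInfo detail, if any, from a gRPC failure so that its
    // metadata (e.g. "messageParameters") can be inspected in assertions.
    def firstErrorInfo(e: StatusRuntimeException): Option[ErrorInfo] = {
      Option(StatusProto.fromThrowable(e)).toSeq
        .flatMap(_.getDetailsList.asScala)
        .collectFirst { case any if any.is(classOf[ErrorInfo]) => any.unpack(classOf[ErrorInfo]) }
    }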
ArtifactManager.scala
@@ -25,6 +25,7 @@ import java.nio.file.{CopyOption, Files, Path, Paths, StandardCopyOption}
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

@@ -266,28 +267,39 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
* they are from a permanent location.
*/
private[sql] def addLocalArtifacts(artifacts: Seq[Artifact]): Unit = {
val failedArtifactExceptions = ListBuffer[SparkRuntimeException]()

artifacts.foreach { artifact =>
artifact.storage match {
case d: Artifact.LocalFile =>
addArtifact(
artifact.path,
d.path,
fragment = None,
deleteStagedFile = false)
case d: Artifact.InMemory =>
val tempDir = Utils.createTempDir().toPath
val tempFile = tempDir.resolve(artifact.path.getFileName)
val outStream = Files.newOutputStream(tempFile)
Utils.tryWithSafeFinallyAndFailureCallbacks {
d.stream.transferTo(outStream)
addArtifact(artifact.path, tempFile, fragment = None)
}(finallyBlock = {
outStream.close()
})
case _ =>
throw SparkException.internalError(s"Unsupported artifact storage: ${artifact.storage}")
try {
artifact.storage match {
case d: Artifact.LocalFile =>
addArtifact(
artifact.path,
d.path,
fragment = None,
deleteStagedFile = false)
case d: Artifact.InMemory =>
val tempDir = Utils.createTempDir().toPath
val tempFile = tempDir.resolve(artifact.path.getFileName)
val outStream = Files.newOutputStream(tempFile)
Utils.tryWithSafeFinallyAndFailureCallbacks {
d.stream.transferTo(outStream)
addArtifact(artifact.path, tempFile, fragment = None)
}(finallyBlock = {
outStream.close()
})
case _ =>
throw SparkException.internalError(s"Unsupported artifact storage: ${artifact.storage}")
}
} catch {
case e: SparkRuntimeException if e.getCondition == "ARTIFACT_ALREADY_EXISTS" =>
failedArtifactExceptions += e
}
}

if (failedArtifactExceptions.nonEmpty) {
heyihong (Contributor) commented on Aug 25, 2025:

The error handling and suppression logic seems to be duplicated in both ArtifactManager.scala and SparkConnectAddArtifactsHandler.scala.

    if (failedArtifactExceptions.nonEmpty) {
      val exception = failedArtifactExceptions.head
      failedArtifactExceptions.drop(1).foreach(exception.addSuppressed(_))
      throw exception
    }

I was wondering whether it makes sense to introduce a small utility to handle Seq[Try[...]] instead.

Reply (Contributor):
ArtifactUtils would be a good place for this utility method
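For illustration only, one possible shape of the Seq[Try[...]] helper suggested above (hypothetical name and signature, not part of this PR; the change as submitted keeps mergeExceptionsWithSuppressed instead):

    import scala.util.{Failure, Success, Try}

    // Hypothetical: run a batch of actions, keep the successes, and raise all failures
    // as a single exception with the remainder attached via addSuppressed.
    def sequenceWithSuppressed[A](results: Seq[Try[A]]): Seq[A] = {
      val failures = results.collect { case Failure(e) => e }
      if (failures.nonEmpty) {
        val main = failures.head
        failures.tail.foreach(main.addSuppressed)
        throw main
      }
      results.collect { case Success(a) => a }
    }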

throw ArtifactUtils.mergeExceptionsWithSuppressed(failedArtifactExceptions.toSeq)
}
}

def classloader: ClassLoader = synchronized {
ArtifactManagerSuite.scala
@@ -20,8 +20,9 @@ import java.io.File
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Path, Paths}

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException}
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.Artifact
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
@@ -346,6 +347,78 @@ class ArtifactManagerSuite extends SharedSparkSession {
}
}

test("Add multiple artifacts to local session and check if all are added despite exception") {
val copyDir = Utils.createTempDir().toPath
Utils.copyDirectory(artifactPath.toFile, copyDir.toFile)

val artifact1Path = "my/custom/pkg/artifact1.jar"
val artifact2Path = "my/custom/pkg/artifact2.jar"
val targetPath = Paths.get(artifact1Path)
val targetPath2 = Paths.get(artifact2Path)

val classPath1 = copyDir.resolve("Hello.class")
val classPath2 = copyDir.resolve("udf_noA.jar")
assume(artifactPath.resolve("Hello.class").toFile.exists)
assume(artifactPath.resolve("smallClassFile.class").toFile.exists)

val artifact1 = Artifact.newArtifactFromExtension(
targetPath.getFileName.toString,
targetPath,
new Artifact.LocalFile(Paths.get(classPath1.toString)))

val alreadyExistingArtifact = Artifact.newArtifactFromExtension(
targetPath2.getFileName.toString,
targetPath,
new Artifact.LocalFile(Paths.get(classPath2.toString)))

val artifact2 = Artifact.newArtifactFromExtension(
targetPath2.getFileName.toString,
targetPath2,
new Artifact.LocalFile(Paths.get(classPath2.toString)))

spark.artifactManager.addLocalArtifacts(Seq(artifact1))

val ex = intercept[SparkRuntimeException] {
spark.artifactManager.addLocalArtifacts(
Seq(alreadyExistingArtifact, artifact2, alreadyExistingArtifact))
}

checkError(
exception = ex,
condition = "ARTIFACT_ALREADY_EXISTS",
parameters = Map("normalizedRemoteRelativePath" -> s"jars/${targetPath.toString}"),
)

assert(ex.getSuppressed.length == 1)
assert(ex.getSuppressed.head.isInstanceOf[SparkRuntimeException])
val suppressed = ex.getSuppressed.head.asInstanceOf[SparkRuntimeException]

checkError(
exception = suppressed,
condition = "ARTIFACT_ALREADY_EXISTS",
parameters = Map("normalizedRemoteRelativePath" -> s"jars/${targetPath.toString}"),
)

// Artifact1 should have been added
val expectedFile1 = ArtifactManager.artifactRootDirectory
.resolve(s"$sessionUUID/jars/$artifact1Path")
.toFile
assert(expectedFile1.exists())

// Artifact2 should have been added despite exception
val expectedFile2 = ArtifactManager.artifactRootDirectory
.resolve(s"$sessionUUID/jars/$artifact2Path")
.toFile
assert(expectedFile2.exists())

// Cleanup
artifactManager.cleanUpResourcesForTesting()
val sessionDir = ArtifactManager.artifactRootDirectory.resolve(sessionUUID).toFile

assert(!expectedFile1.exists())
assert(!sessionDir.exists())
}

test("Added artifact can be loaded by the current SparkSession") {
val path = artifactPath.resolve("IntSumUdf.class")
assume(path.toFile.exists)