
Commit 5885ce0

Hendrik Huebner (HendrikHuebner) authored and committed
Improve exception handling when adding artifacts
1 parent 77413d4 commit 5885ce0

File tree: 4 files changed, +223 −30 lines

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala

Lines changed: 29 additions & 8 deletions
@@ -26,6 +26,7 @@ import scala.util.control.NonFatal
 import com.google.common.io.CountingOutputStream
 import io.grpc.stub.StreamObserver
 
+import org.apache.spark.SparkRuntimeException
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
 import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
@@ -112,19 +113,34 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
    * @return
    */
   protected def flushStagedArtifacts(): Seq[ArtifactSummary] = {
+    val failedArtifactExceptions = mutable.ListBuffer[SparkRuntimeException]()
+
     // Non-lazy transformation when using Buffer.
-    stagedArtifacts.map { artifact =>
-      // We do not store artifacts that fail the CRC. The failure is reported in the artifact
-      // summary and it is up to the client to decide whether to retry sending the artifact.
-      if (artifact.getCrcStatus.contains(true)) {
-        if (artifact.path.startsWith(ArtifactManager.forwardToFSPrefix + File.separator)) {
-          holder.artifactManager.uploadArtifactToFs(artifact.path, artifact.stagedPath)
-        } else {
-          addStagedArtifactToArtifactManager(artifact)
+    val summaries = stagedArtifacts.map { artifact =>
+      try {
+        // We do not store artifacts that fail the CRC. The failure is reported in the artifact
+        // summary and it is up to the client to decide whether to retry sending the artifact.
+        if (artifact.getCrcStatus.contains(true)) {
+          if (artifact.path.startsWith(ArtifactManager.forwardToFSPrefix + File.separator)) {
+            holder.artifactManager.uploadArtifactToFs(artifact.path, artifact.stagedPath)
+          } else {
+            addStagedArtifactToArtifactManager(artifact)
+          }
         }
+      } catch {
+        case e: SparkRuntimeException if e.getCondition == "ARTIFACT_ALREADY_EXISTS" =>
+          failedArtifactExceptions += e
       }
       artifact.summary()
     }.toSeq
+
+    if (failedArtifactExceptions.nonEmpty) {
+      val exception = failedArtifactExceptions.head
+      failedArtifactExceptions.drop(1).foreach(exception.addSuppressed(_))
+      throw exception
+    }
+
+    summaries
   }
 
   protected def cleanUpStagedArtifacts(): Unit = Utils.deleteRecursively(stagingDir.toFile)
@@ -216,6 +232,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
     private val fileOut = Files.newOutputStream(stagedPath)
     private val countingOut = new CountingOutputStream(fileOut)
     private val checksumOut = new CheckedOutputStream(countingOut, new CRC32)
+    private val overallChecksum = new CRC32()
 
     private val builder = ArtifactSummary.newBuilder().setName(name)
     private var artifactSummary: ArtifactSummary = _
@@ -227,13 +244,17 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
 
     def getCrcStatus: Option[Boolean] = Option(isCrcSuccess)
 
+    def getCrc: Long = overallChecksum.getValue
+
     def write(dataChunk: proto.AddArtifactsRequest.ArtifactChunk): Unit = {
       try dataChunk.getData.writeTo(checksumOut)
      catch {
        case NonFatal(e) =>
          close()
          throw e
      }
+
+      overallChecksum.update(dataChunk.getData.toByteArray)
       updateCrc(checksumOut.getChecksum.getValue == dataChunk.getCrc)
       checksumOut.getChecksum.reset()
     }
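
Note: the core of this change is a collect-then-rethrow pattern that recurs in the ArtifactManager diff below: keep processing every artifact, record each ARTIFACT_ALREADY_EXISTS failure, then throw the first failure with the rest attached via addSuppressed. A minimal standalone sketch of that pattern follows; the Item type and process function are illustrative placeholders, not part of this commit.

import scala.collection.mutable

// Minimal sketch of the collect-then-rethrow pattern used above.
// Item and process are illustrative placeholders, not Spark APIs.
object SuppressedFailuresSketch {
  final case class Item(name: String)

  private def process(item: Item): Unit =
    if (item.name.startsWith("dup")) {
      throw new RuntimeException(s"artifact already exists: ${item.name}")
    }

  def processAll(items: Seq[Item]): Seq[String] = {
    val failures = mutable.ListBuffer[RuntimeException]()

    // Attempt every item, even if an earlier one failed.
    val results = items.map { item =>
      try process(item)
      catch { case e: RuntimeException => failures += e }
      item.name
    }

    // Surface all failures at once: throw the first, attach the rest as suppressed.
    if (failures.nonEmpty) {
      val primary = failures.head
      failures.drop(1).foreach(e => primary.addSuppressed(e))
      throw primary
    }
    results
  }

  def main(args: Array[String]): Unit =
    try processAll(Seq(Item("a"), Item("dup-b"), Item("c"), Item("dup-d")))
    catch {
      case e: RuntimeException =>
        println(s"primary: ${e.getMessage}, suppressed: ${e.getSuppressed.length}")
    }
}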

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala

Lines changed: 94 additions & 2 deletions
@@ -32,17 +32,19 @@ import io.grpc.StatusRuntimeException
 import io.grpc.protobuf.StatusProto
 import io.grpc.stub.StreamObserver
 
+import org.apache.spark.SparkRuntimeException
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
 import org.apache.spark.sql.connect.ResourceHelper
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.ThreadUtils
 
-class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
+class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
 
   private val CHUNK_SIZE: Int = 32 * 1024
 
   private val sessionId = UUID.randomUUID.toString()
+  private val sessionKey = SessionKey("c1", sessionId)
 
   class DummyStreamObserver(p: Promise[AddArtifactsResponse])
       extends StreamObserver[AddArtifactsResponse] {
@@ -51,17 +53,31 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
     override def onCompleted(): Unit = {}
   }
 
-  class TestAddArtifactsHandler(responseObserver: StreamObserver[AddArtifactsResponse])
+  class TestAddArtifactsHandler(responseObserver: StreamObserver[AddArtifactsResponse],
+      throwIfArtifactExists: Boolean = false)
       extends SparkConnectAddArtifactsHandler(responseObserver) {
 
     // Stop the staged artifacts from being automatically deleted
     override protected def cleanUpStagedArtifacts(): Unit = {}
 
     private val finalArtifacts = mutable.Buffer.empty[String]
+    private val artifactChecksums: mutable.Map[String, Long] = mutable.Map.empty
 
     // Record the artifacts that are sent out for final processing.
    override protected def addStagedArtifactToArtifactManager(artifact: StagedArtifact): Unit = {
+      // Throw if artifact already exists and has different checksum
+      // This mocks the behavior of ArtifactManager.addArtifact without comparing the entire file
+      if (throwIfArtifactExists
+          && finalArtifacts.contains(artifact.name)
+          && artifact.getCrc != artifactChecksums(artifact.name)) {
+        throw new SparkRuntimeException(
+          "ARTIFACT_ALREADY_EXISTS",
+          Map("normalizedRemoteRelativePath" -> artifact.name)
+        )
+      }
+
      finalArtifacts.append(artifact.name)
+      artifactChecksums += (artifact.name -> artifact.getCrc)
    }
 
    def getFinalArtifacts: Seq[String] = finalArtifacts.toSeq
@@ -418,4 +434,80 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
     }
   }
 
+
+  def addSingleChunkArtifact(
+      handler: SparkConnectAddArtifactsHandler,
+      sessionKey: SessionKey,
+      name: String,
+      artifactPath: Path): Unit = {
+    val dataChunks = getDataChunks(artifactPath)
+    assert(dataChunks.size == 1)
+    val bytes = dataChunks.head
+    val context = proto.UserContext
+      .newBuilder()
+      .setUserId(sessionKey.userId)
+      .build()
+    val fileNameNoExtension = artifactPath.getFileName.toString.split('.').head
+    val singleChunkArtifact = proto.AddArtifactsRequest.SingleChunkArtifact
+      .newBuilder()
+      .setName(name)
+      .setData(
+        proto.AddArtifactsRequest.ArtifactChunk
+          .newBuilder()
+          .setData(bytes)
+          .setCrc(getCrcValues(crcPath.resolve(fileNameNoExtension + ".txt")).head)
+          .build())
+      .build()
+
+    val singleChunkArtifactRequest = AddArtifactsRequest
+      .newBuilder()
+      .setSessionId(sessionKey.sessionId)
+      .setUserContext(context)
+      .setBatch(
+        proto.AddArtifactsRequest.Batch.newBuilder().addArtifacts(singleChunkArtifact).build())
+      .build()
+
+    handler.onNext(singleChunkArtifactRequest)
+  }
+
+  test("All artifacts are added, even if some fail") {
+    val promise = Promise[AddArtifactsResponse]()
+    val handler = new TestAddArtifactsHandler(new DummyStreamObserver(promise),
+      throwIfArtifactExists = true)
+    try {
+      val name1 = "jars/dummy1.jar"
+      val name2 = "jars/dummy2.jar"
+      val name3 = "jars/dummy3.jar"
+
+      val artifactPath1 = inputFilePath.resolve("smallClassFile.class")
+      val artifactPath2 = inputFilePath.resolve("smallJar.jar")
+
+      assume(artifactPath1.toFile.exists)
+      addSingleChunkArtifact(handler, sessionKey, name1, artifactPath1)
+      addSingleChunkArtifact(handler, sessionKey, name3, artifactPath1)
+
+      val e = intercept[StatusRuntimeException] {
+        addSingleChunkArtifact(handler, sessionKey, name1, artifactPath2)
+        addSingleChunkArtifact(handler, sessionKey, name2, artifactPath1)
+        addSingleChunkArtifact(handler, sessionKey, name3, artifactPath2)
+        handler.onCompleted()
+      }
+
+      // Both artifacts should be added, despite exception
+      assert(handler.getFinalArtifacts.contains(name1))
+      assert(handler.getFinalArtifacts.contains(name2))
+      assert(handler.getFinalArtifacts.contains(name3))
+
+      assert(e.getStatus.getCode == Code.INTERNAL)
+      val statusProto = StatusProto.fromThrowable(e)
+      assert(statusProto.getDetailsCount == 1)
+      val details = statusProto.getDetails(0)
+      val info = details.unpack(classOf[ErrorInfo])
+
+      assert(e.getMessage.contains("ARTIFACT_ALREADY_EXISTS"))
+      assert(info.getMetadataMap().get("messageParameters").contains(name1))
+    } finally {
+      handler.forceCleanUp()
+    }
+  }
 }
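
Note: the new overallChecksum/getCrc plumbing in the handler is a running java.util.zip.CRC32 accumulated over every chunk of an artifact; the test handler above compares that value to decide whether a re-sent artifact name carries different content. A minimal sketch of a running CRC32 over several chunks (the chunk contents are illustrative placeholders):

import java.util.zip.CRC32

// Minimal sketch: a running CRC32 over several data chunks, analogous to the
// handler's overallChecksum across ArtifactChunk payloads.
object RunningCrcSketch {
  def crcOfChunks(chunks: Seq[Array[Byte]]): Long = {
    val crc = new CRC32()
    chunks.foreach(chunk => crc.update(chunk)) // CRC32.update accepts a byte array
    crc.getValue                               // checksum over everything seen so far
  }

  def main(args: Array[String]): Unit = {
    val chunks = Seq("chunk-1".getBytes("UTF-8"), "chunk-2".getBytes("UTF-8"))
    println(s"overall CRC32 = ${crcOfChunks(chunks)}")
  }
}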

sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala

Lines changed: 33 additions & 19 deletions
@@ -25,6 +25,7 @@ import java.nio.file.{CopyOption, Files, Path, Paths, StandardCopyOption}
 import java.util.concurrent.CopyOnWriteArrayList
 import java.util.concurrent.atomic.AtomicBoolean
 
+import scala.collection.mutable.ListBuffer
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
 
@@ -266,28 +267,41 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
    * they are from a permanent location.
    */
   private[sql] def addLocalArtifacts(artifacts: Seq[Artifact]): Unit = {
+    val failedArtifactExceptions = ListBuffer[RuntimeException]()
+
     artifacts.foreach { artifact =>
-      artifact.storage match {
-        case d: Artifact.LocalFile =>
-          addArtifact(
-            artifact.path,
-            d.path,
-            fragment = None,
-            deleteStagedFile = false)
-        case d: Artifact.InMemory =>
-          val tempDir = Utils.createTempDir().toPath
-          val tempFile = tempDir.resolve(artifact.path.getFileName)
-          val outStream = Files.newOutputStream(tempFile)
-          Utils.tryWithSafeFinallyAndFailureCallbacks {
-            d.stream.transferTo(outStream)
-            addArtifact(artifact.path, tempFile, fragment = None)
-          }(finallyBlock = {
-            outStream.close()
-          })
-        case _ =>
-          throw SparkException.internalError(s"Unsupported artifact storage: ${artifact.storage}")
+      try {
+        artifact.storage match {
+          case d: Artifact.LocalFile =>
+            addArtifact(
+              artifact.path,
+              d.path,
+              fragment = None,
+              deleteStagedFile = false)
+          case d: Artifact.InMemory =>
+            val tempDir = Utils.createTempDir().toPath
+            val tempFile = tempDir.resolve(artifact.path.getFileName)
+            val outStream = Files.newOutputStream(tempFile)
+            Utils.tryWithSafeFinallyAndFailureCallbacks {
+              d.stream.transferTo(outStream)
+              addArtifact(artifact.path, tempFile, fragment = None)
+            }(finallyBlock = {
+              outStream.close()
+            })
+          case _ =>
+            throw SparkException.internalError(s"Unsupported artifact storage: ${artifact.storage}")
+        }
+      } catch {
+        case e: SparkRuntimeException =>
+          failedArtifactExceptions += e
       }
     }
+
+    if (failedArtifactExceptions.nonEmpty) {
+      val exception = failedArtifactExceptions.head
+      failedArtifactExceptions.drop(1).foreach(exception.addSuppressed(_))
+      throw exception
+    }
   }
 
   def classloader: ClassLoader = synchronized {
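
Note: with this change a caller of addLocalArtifacts sees a single SparkRuntimeException whose suppressed exceptions carry the remaining failures. A minimal sketch of inspecting that shape; the artifact names are illustrative, the two-argument SparkRuntimeException constructor and getCondition are used exactly as in the diffs above, and Spark is assumed to be on the classpath:

import org.apache.spark.SparkRuntimeException

// Sketch of the error shape produced by addLocalArtifacts after this change:
// one primary ARTIFACT_ALREADY_EXISTS failure, with the remaining failures
// attached as suppressed exceptions.
object SuppressedInspectionSketch {
  private def alreadyExists(name: String): SparkRuntimeException =
    new SparkRuntimeException(
      "ARTIFACT_ALREADY_EXISTS",
      Map("normalizedRemoteRelativePath" -> name))

  def main(args: Array[String]): Unit = {
    val primary = alreadyExists("jars/dummy1.jar")
    primary.addSuppressed(alreadyExists("jars/dummy3.jar"))

    // A caller catches a single exception but can still recover every failed artifact.
    (primary +: primary.getSuppressed.toSeq).foreach {
      case e: SparkRuntimeException => println(s"failed artifact: ${e.getCondition}")
      case e                        => println(s"failed artifact: ${e.getMessage}")
    }
  }
}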

sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala

Lines changed: 67 additions & 1 deletion
@@ -20,8 +20,9 @@ import java.io.File
 import java.nio.charset.StandardCharsets
 import java.nio.file.{Files, Path, Paths}
 
-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException}
 import org.apache.spark.metrics.source.CodegenMetrics
+import org.apache.spark.sql.Artifact
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
@@ -346,6 +347,71 @@ class ArtifactManagerSuite extends SharedSparkSession {
     }
   }
 
+  test("Add multiple artifacts to local session and check if all are added despite exception") {
+    val copyDir = Utils.createTempDir().toPath
+    Utils.copyDirectory(artifactPath.toFile, copyDir.toFile)
+
+    val artifact1Path = "my/custom/pkg/artifact1.jar"
+    val artifact2Path = "my/custom/pkg/artifact2.jar"
+    val targetPath = Paths.get(artifact1Path)
+    val targetPath2 = Paths.get(artifact2Path)
+
+    val classPath1 = copyDir.resolve("Hello.class")
+    val classPath2 = copyDir.resolve("smallJar.jar")
+    assume(artifactPath.resolve("Hello.class").toFile.exists)
+    assume(artifactPath.resolve("smallClassFile.class").toFile.exists)
+
+    val artifact1 = Artifact.newArtifactFromExtension(
+      targetPath.getFileName.toString,
+      targetPath,
+      new Artifact.LocalFile(Paths.get(classPath1.toString)))
+
+    val alreadyExistingArtifact = Artifact.newArtifactFromExtension(
+      targetPath2.getFileName.toString,
+      targetPath,
+      new Artifact.LocalFile(Paths.get(classPath2.toString)))
+
+    val artifact2 = Artifact.newArtifactFromExtension(
+      targetPath2.getFileName.toString,
+      targetPath2,
+      new Artifact.LocalFile(Paths.get(classPath2.toString)))
+
+    spark.artifactManager.addLocalArtifacts(Seq(artifact1))
+
+    val exception = intercept[SparkRuntimeException] {
+      spark.artifactManager.addLocalArtifacts(
+        Seq(alreadyExistingArtifact, artifact2, alreadyExistingArtifact))
+    }
+
+    // Validate exception: Should be ARTIFACT_ALREADY_EXISTS and have one suppressed exception
+    assert(exception.getCondition == "ARTIFACT_ALREADY_EXISTS",
+      s"Expected ARTIFACT_ALREADY_EXISTS but got: ${exception.getCondition}")
+
+    assert(exception.getSuppressed.length == 1)
+    assert(exception.getSuppressed.head.isInstanceOf[SparkRuntimeException])
+    val suppressed = exception.getSuppressed.head.asInstanceOf[SparkRuntimeException]
+    assert(suppressed.getCondition == "ARTIFACT_ALREADY_EXISTS")
+
+    // Artifact1 should have been added
+    val expectedFile1 = ArtifactManager.artifactRootDirectory
+      .resolve(s"$sessionUUID/jars/$artifact1Path")
+      .toFile
+    assert(expectedFile1.exists())
+
+    // Artifact2 should have been added despite exception
+    val expectedFile2 = ArtifactManager.artifactRootDirectory
+      .resolve(s"$sessionUUID/jars/$artifact2Path")
+      .toFile
+    assert(expectedFile2.exists())
+
+    // Cleanup
+    artifactManager.cleanUpResourcesForTesting()
+    val sessionDir = ArtifactManager.artifactRootDirectory.resolve(sessionUUID).toFile
+
+    assert(!expectedFile1.exists())
+    assert(!sessionDir.exists())
+  }
+
   test("Added artifact can be loaded by the current SparkSession") {
     val path = artifactPath.resolve("IntSumUdf.class")
     assume(path.toFile.exists)
0 commit comments