diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 1cd63849de636..fe74f034e2933 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -1022,11 +1022,19 @@ private[spark] class SparkSubmit extends Logging {
         e
     }
 
+    var exitCode: Int = 1
     try {
       app.start(childArgs.toArray, sparkConf)
+      exitCode = 0
     } catch {
       case t: Throwable =>
-        throw findCause(t)
+        val cause = findCause(t)
+        cause match {
+          case e: SparkUserAppException =>
+            exitCode = e.exitCode
+          case _ =>
+        }
+        throw cause
     } finally {
       if (args.master.startsWith("k8s") && !isShell(args.primaryResource) &&
           !isSqlShell(args.mainClass) && !isThriftServer(args.mainClass) &&
@@ -1037,6 +1045,12 @@ private[spark] class SparkSubmit extends Logging {
           case e: Throwable => logError("Failed to close SparkContext", e)
         }
       }
+      if (sparkConf.get(SUBMIT_CALL_SYSTEM_EXIT_ON_MAIN_EXIT)) {
+        logInfo(
+          log"Calling System.exit() with exit code ${MDC(LogKeys.EXIT_CODE, exitCode)} " +
+            log"because ${MDC(LogKeys.CONFIG, SUBMIT_CALL_SYSTEM_EXIT_ON_MAIN_EXIT.key)}=true")
+        exitFn(exitCode)
+      }
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index b6a8b24f6fc77..0bee708bca3c7 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2311,6 +2311,15 @@ package object config {
       .toSequence
       .createWithDefault(Nil)
 
+  private[spark] val SUBMIT_CALL_SYSTEM_EXIT_ON_MAIN_EXIT =
+    ConfigBuilder("spark.submit.callSystemExitOnMainExit")
+      .doc("If true, SparkSubmit will call System.exit() to initiate JVM shutdown once the " +
+        "user's main method has exited. This can be useful in cases where non-daemon JVM " +
+        "threads might otherwise prevent the JVM from shutting down on its own.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   private[spark] val SCHEDULER_ALLOCATION_FILE =
     ConfigBuilder("spark.scheduler.allocation.file")
       .version("0.8.1")
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index dbecad2df7689..0db8ba785fcbc 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -1593,6 +1593,44 @@ class SparkSubmitSuite
       runSparkSubmit(argsSuccess, expectFailure = false))
   }
 
+  test("spark.submit.callSystemExitOnMainExit returns non-zero exit code on unclean main exit") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+    val args = Seq(
+      "--class", MainThrowsUncaughtExceptionSparkApplicationTest.getClass.getName.stripSuffix("$"),
+      "--name", "testApp",
+      "--conf", s"${SUBMIT_CALL_SYSTEM_EXIT_ON_MAIN_EXIT.key}=true",
+      unusedJar.toString
+    )
+    assertResult(1)(runSparkSubmit(args, expectFailure = true))
+  }
+
+  test("spark.submit.callSystemExitOnMainExit calls system exit on clean main exit") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+    val args = Seq(
+      "--class", NonDaemonThreadSparkApplicationTest.getClass.getName.stripSuffix("$"),
+      "--name", "testApp",
+      "--conf", s"${SUBMIT_CALL_SYSTEM_EXIT_ON_MAIN_EXIT.key}=true",
+      unusedJar.toString
+    )
+    // With SUBMIT_CALL_SYSTEM_EXIT_ON_MAIN_EXIT set to false, the non-daemon thread would
+    // prevent the JVM from beginning shutdown and the following call would fail with a
+    // timeout:
+    assertResult(0)(runSparkSubmit(args))
+  }
+
+  test("spark.submit.callSystemExitOnMainExit with main that explicitly calls System.exit") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+    val args = Seq(
+      "--class",
+      MainExplicitlyCallsSystemExit3SparkApplicationTest.getClass.getName.stripSuffix("$"),
+      "--name", "testApp",
+      "--conf", s"${SUBMIT_CALL_SYSTEM_EXIT_ON_MAIN_EXIT.key}=true",
+      unusedJar.toString
+    )
+    // This main class explicitly exits with System.exit(3), hence this expected exit code:
+    assertResult(3)(runSparkSubmit(args, expectFailure = true))
+  }
+
   private def testRemoteResources(
       enableHttpFs: Boolean,
       forceDownloadSchemes: Seq[String] = Nil): Unit = {
@@ -1876,6 +1914,34 @@ object SimpleApplicationTest {
   }
 }
 
+object MainThrowsUncaughtExceptionSparkApplicationTest {
+  def main(args: Array[String]): Unit = {
+    throw new Exception("User exception")
+  }
+}
+
+object NonDaemonThreadSparkApplicationTest {
+  def main(args: Array[String]): Unit = {
+    val nonDaemonThread: Thread = new Thread {
+      override def run(): Unit = {
+        while (true) {
+          Thread.sleep(1000)
+        }
+      }
+    }
+    nonDaemonThread.setDaemon(false)
+    nonDaemonThread.setName("Non-Daemon-Thread")
+    nonDaemonThread.start()
+    // Fall off the end of the main method.
+  }
+}
+
+object MainExplicitlyCallsSystemExit3SparkApplicationTest {
+  def main(args: Array[String]): Unit = {
+    System.exit(3)
+  }
+}
+
 object UserClasspathFirstTest {
   def main(args: Array[String]): Unit = {
     val ccl = Thread.currentThread().getContextClassLoader()
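A minimal usage sketch (not part of the diff above) of the scenario this conf targets. The jar path and main class are placeholders; SparkLauncher and its setConf/launch methods are the existing org.apache.spark.launcher API. With spark.submit.callSystemExitOnMainExit=true, the spark-submit JVM should exit once the user's main method returns, so waitFor() should not hang on leftover non-daemon threads:

import org.apache.spark.launcher.SparkLauncher

object CallSystemExitOnMainExitExample {
  def main(args: Array[String]): Unit = {
    // Placeholder jar path and main class; substitute a real application.
    val process = new SparkLauncher()
      .setAppResource("/path/to/app.jar")
      .setMainClass("com.example.Main")
      .setMaster("local[*]")
      .setConf("spark.submit.callSystemExitOnMainExit", "true")
      .launch()
    // Per the SparkSubmit change above, the exit code is 0 on a clean main exit,
    // 1 on an uncaught exception, or the code carried by SparkUserAppException
    // (e.g. from a user call to System.exit(n)).
    val exitCode = process.waitFor()
    println(s"spark-submit exited with code $exitCode")
  }
}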