diff --git a/integration-tests/kyuubi-kubernetes-it/src/test/scala/org/apache/kyuubi/kubernetes/test/spark/SparkOnKubernetesTestsSuite.scala b/integration-tests/kyuubi-kubernetes-it/src/test/scala/org/apache/kyuubi/kubernetes/test/spark/SparkOnKubernetesTestsSuite.scala
index 562ee63799a..5518246f62c 100644
--- a/integration-tests/kyuubi-kubernetes-it/src/test/scala/org/apache/kyuubi/kubernetes/test/spark/SparkOnKubernetesTestsSuite.scala
+++ b/integration-tests/kyuubi-kubernetes-it/src/test/scala/org/apache/kyuubi/kubernetes/test/spark/SparkOnKubernetesTestsSuite.scala
@@ -254,4 +254,56 @@ class KyuubiOperationKubernetesClusterClusterModeSuite
       sessionHandle.identifier.toString)
     assert(!failKillResponse._1)
   }
+  test(
+    "If spark batch reach timeout, it should have associated Kyuubi Application Operation be " +
+      "in TIMEOUT state with Spark Driver Engine be in NOT_FOUND state!") {
+    import scala.collection.JavaConverters._
+    // Configure a very small submit timeout to trigger the timeout => 1000ms!
+    val originalTimeout = conf.get(ENGINE_KUBERNETES_SUBMIT_TIMEOUT)
+    conf.set(ENGINE_KUBERNETES_SUBMIT_TIMEOUT, 1000L)
+
+    try {
+      // Prepare a metadata row only (INITIALIZED), without actually launching a Spark driver
+      val batchId = UUID.randomUUID().toString
+      val batchRequest = newSparkBatchRequest(conf.getAll ++ Map(
+        KYUUBI_BATCH_ID_KEY -> batchId))
+
+      val user = "test-user"
+      val ipAddress = "test-ip"
+
+      // Insert the metadata so that subsequent update can find this record
+      sessionManager.initializeBatchState(
+        user,
+        ipAddress,
+        batchRequest.getConf.asScala.toMap,
+        batchRequest)
+
+      // Create a fresh KubernetesApplicationOperation that can trigger update
+      // to metadata upon timeout!
+      val operation = new KubernetesApplicationOperation
+      operation.initialize(conf, sessionManager.metadataManager)
+
+      // Use a submitTime far enough in the past to exceed the timeout
+      val submitTime = Some(System.currentTimeMillis() - 10000L)
+
+      // No driver pod exists for this random batch id, so this should hit the timeout path
+      operation.getApplicationInfoByTag(
+        appMgrInfo,
+        batchId,
+        Some(user),
+        submitTime)
+
+      eventually(timeout(30.seconds), interval(200.milliseconds)) {
+        val mdOpt = sessionManager.getBatchMetadata(batchId)
+        assert(mdOpt.isDefined)
+        val md = mdOpt.get
+        // Verify metadata reflects TIMEOUT and NOT_FOUND as set by the timeout handling
+        assert(md.state == org.apache.kyuubi.operation.OperationState.TIMEOUT.toString)
+        assert(md.engineState == NOT_FOUND.toString)
+      }
+    } finally {
+      // restore back original engine submit time out for kyuubi batch job submission!
+      conf.set(ENGINE_KUBERNETES_SUBMIT_TIMEOUT, originalTimeout)
+    }
+  }
 }
diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
index 65864986552..4f2efc7d9b2 100644
--- a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
+++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
@@ -1910,6 +1910,16 @@ object KyuubiConf {
         " Kyuubi instances.")
       .timeConf
       .createWithDefault(Duration.ofSeconds(20).toMillis)
+
+  val BATCH_SESSIONS_RECOVERY_SIZE: ConfigEntry[Int] =
+    buildConf("kyuubi.batch.sessions.recovery.size")
+      .serverOnly
+      .internal
+      .doc("The size per batch of kyuubi batch metadata records to fetch and create " +
+        "associated kyuubi sessions at a time for recovery upon restart of kyuubi server")
+      .intConf
+      .checkValue(_ > 0, "must be positive number")
+      .createWithDefault(10)
 
   val BATCH_INTERNAL_REST_CLIENT_CONNECT_TIMEOUT: ConfigEntry[Long] =
     buildConf("kyuubi.batch.internal.rest.client.connect.timeout")
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/KubernetesApplicationOperation.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/KubernetesApplicationOperation.scala
index 703d223d5bc..b5a40d6b53d 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/KubernetesApplicationOperation.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/engine/KubernetesApplicationOperation.scala
@@ -322,7 +322,26 @@ class KubernetesApplicationOperation extends ApplicationOperation with Logging {
           if (elapsedTime > submitTimeout) {
             error(s"Can't find target driver pod by ${toLabel(tag)}, " +
               s"elapsed time: ${elapsedTime}ms exceeds ${submitTimeout}ms.")
-            ApplicationInfo.NOT_FOUND
+            val errorMsg =
+              s"Driver pod not found for job with kyuubi-unique-tag: $tag after $elapsedTime ms " +
+                s"(submit-timeout: $submitTimeout ms)"
+            // Mark the operation as timed out and the driver engine as not found in the
+            // metadata store, so that a restarted Kyuubi server does not keep polling for
+            // this batch job's status.
+            try {
+              metadataManager.foreach(_.updateMetadata(
+                org.apache.kyuubi.server.metadata.api.Metadata(
+                  identifier = tag,
+                  state = org.apache.kyuubi.operation.OperationState.TIMEOUT.toString,
+                  engineState = ApplicationState.NOT_FOUND.toString,
+                  engineError = Some(errorMsg),
+                  endTime = System.currentTimeMillis())))
+            } catch {
+              case NonFatal(e) =>
+                warn(s"Failed to update metadata for spark job with kyuubi-unique-tag label: " +
+                  s"$tag after submit timeout reached", e)
+            }
+            appInfo
           } else {
             warn(s"Waiting for driver pod with ${toLabel(tag)} to be created, " +
               s"elapsed time: ${elapsedTime}ms, return UNKNOWN status")
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/KyuubiRestFrontendService.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/KyuubiRestFrontendService.scala
index 787ac0b0473..f153b3fc115 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/KyuubiRestFrontendService.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/KyuubiRestFrontendService.scala
@@ -181,36 +181,59 @@ class KyuubiRestFrontendService(override val serverable: Serverable)
   @VisibleForTesting
   private[kyuubi] def recoverBatchSessions(): Unit = withBatchRecoveryLockRequired {
     val recoveryNumThreads = conf.get(METADATA_RECOVERY_THREADS)
+    val recoveryBatchSize: Int = conf.get(BATCH_SESSIONS_RECOVERY_SIZE)
     val batchRecoveryExecutor =
       ThreadUtils.newDaemonFixedThreadPool(recoveryNumThreads, "batch-recovery-executor")
     try {
-      val batchSessionsToRecover = sessionManager.getBatchSessionsToRecover(connectionUrl)
-      val pendingRecoveryTasksCount = new AtomicInteger(0)
-      val tasks = batchSessionsToRecover.flatMap { batchSession =>
-        val batchId = batchSession.batchJobSubmissionOp.batchId
-        try {
-          val task: Future[Unit] = batchRecoveryExecutor.submit(() =>
-            Utils.tryLogNonFatalError(sessionManager.openBatchSession(batchSession)))
-          Some(task -> batchId)
-        } catch {
-          case e: Throwable =>
-            error(s"Error while submitting batch[$batchId] for recovery", e)
-            None
-        }
-      }
+      // Recover page by page so that at most `recoveryBatchSize` metadata records per state
+      // are loaded from the metadata store at a time.
+      // NOTE(review): offset paging assumes the set of PENDING/RUNNING records stays stable
+      // while recovery is in progress - confirm against the metadata store semantics.
+      var offset: Int = 0
+      var shouldFetchRemainingBatchSessions: Boolean = true
+      var totalBatchRecovered: Int = 0
+      while (shouldFetchRemainingBatchSessions) {
+        val batchSessionsToRecover =
+          sessionManager.getBatchSessionsToRecover(connectionUrl, offset, recoveryBatchSize)
+        if (batchSessionsToRecover.nonEmpty) {
+          val pendingRecoveryTasksCount = new AtomicInteger(0)
+          val tasks = batchSessionsToRecover.flatMap { batchSession =>
+            val batchId = batchSession.batchJobSubmissionOp.batchId
+            try {
+              val task: Future[Unit] = batchRecoveryExecutor.submit(() =>
+                Utils.tryLogNonFatalError(sessionManager.openBatchSession(batchSession)))
+              Some(task -> batchId)
+            } catch {
+              case e: Throwable =>
+                error(s"Error while submitting batch[$batchId] for recovery", e)
+                None
+            }
+          }
 
-      pendingRecoveryTasksCount.addAndGet(tasks.size)
+          pendingRecoveryTasksCount.addAndGet(tasks.size)
 
-      tasks.foreach { case (task, batchId) =>
-        try {
-          task.get()
-        } catch {
-          case e: Throwable =>
-            error(s"Error while recovering batch[$batchId]", e)
-        } finally {
-          val pendingTasks = pendingRecoveryTasksCount.decrementAndGet()
-          info(s"Batch[$batchId] recovery task terminated, current pending tasks $pendingTasks")
+          tasks.foreach { case (task, batchId) =>
+            try {
+              task.get()
+            } catch {
+              case e: Throwable =>
+                error(s"Error while recovering batch[$batchId]", e)
+            } finally {
+              val pendingTasks = pendingRecoveryTasksCount.decrementAndGet()
+              info(s"Batch[$batchId] recovery task terminated, " +
+                s"current pending tasks $pendingTasks")
+            }
+          }
+          totalBatchRecovered += batchSessionsToRecover.length
+          // The offset is applied to every recovered state independently, so advance it by
+          // the page size rather than by the combined number of sessions returned.
+          offset += recoveryBatchSize
+        } else {
+          shouldFetchRemainingBatchSessions = false
+          info("No more batches left to recover from metadata store.")
         }
+        info(s"Recovered $totalBatchRecovered batches total so far successfully " +
+          "from metadata store.")
       }
     } finally {
       ThreadUtils.shutdown(batchRecoveryExecutor)
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/session/KyuubiSessionManager.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/session/KyuubiSessionManager.scala
index 344da0e71e8..07417ea4eeb 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/session/KyuubiSessionManager.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/session/KyuubiSessionManager.scala
@@ -313,13 +313,25 @@ class KyuubiSessionManager private (name: String) extends SessionManager(name) {
     startEngineAliveChecker()
   }
 
-  def getBatchSessionsToRecover(kyuubiInstance: String): Seq[KyuubiBatchSession] = {
+  /**
+   * Creates batch sessions from the PENDING and RUNNING batch recovery metadata of the
+   * given kyuubi instance. The paging arguments are passed to the lookup of each state.
+   *
+   * @param kyuubiInstance the kyuubi instance passed to the metadata lookup
+   * @param offset the offset passed to each per-state metadata lookup
+   * @param batchSize the size passed to each per-state metadata lookup
+   * @return the batch sessions created from the fetched metadata records
+   */
+  def getBatchSessionsToRecover(
+      kyuubiInstance: String,
+      offset: Int = 0,
+      batchSize: Int = Int.MaxValue): Seq[KyuubiBatchSession] = {
     Seq(OperationState.PENDING, OperationState.RUNNING).flatMap { stateToRecover =>
       metadataManager.map(_.getBatchesRecoveryMetadata(
         stateToRecover.toString,
         kyuubiInstance,
-        0,
-        Int.MaxValue).map { metadata =>
+        offset,
+        batchSize).map { metadata =>
         createBatchSessionFromRecovery(metadata)
       }).getOrElse(Seq.empty)
     }