diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index b2f185bc590f..5966f96518d0 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -33,6 +33,7 @@ import io.netty.util.internal.OutOfDirectMemoryError
 import org.roaringbitmap.RoaringBitmap
 
 import org.apache.spark.{MapOutputTracker, SparkException, TaskContext}
+import org.apache.spark.ExecutorDeadException
 import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
 import org.apache.spark.errors.SparkCoreErrors
 import org.apache.spark.internal.Logging
@@ -979,8 +980,48 @@ final class ShuffleBlockFetcherIterator(
              log"${MDC(MAX_ATTEMPTS, maxAttemptsOnNettyOOM)} retries due to Netty OOM"
            logError(logMessage)
            errorMsg = logMessage.message
+          } else {
+            logInfo(s"Block $blockId fetch failed for mapIndex $mapIndex from $address")
+          }
+          // If the fetch failed due to a dead executor, the block may still be available
+          // elsewhere after graceful storage decommissioning / block migration.
+          // Rather than failing and potentially retrying the whole stage, check for a
+          // new block location and re-queue the fetch with the updated BlockManagerId.
+          // TODO: skip this if spark.storage.decommission.enabled = false
+          if (e.isInstanceOf[ExecutorDeadException]) {
+            val newBlocksByAddr = blockId match {
+              case ShuffleBlockId(shuffleId, _, reduceId) =>
+                mapOutputTracker.unregisterShuffle(shuffleId)
+                mapOutputTracker.getMapSizesByExecutorId(
+                  shuffleId,
+                  mapIndex,
+                  mapIndex + 1,
+                  reduceId,
+                  reduceId + 1)
+                  .filter(_._1 != address)
+              case ShuffleBlockBatchId(shuffleId, _, startReduceId, endReduceId) =>
+                mapOutputTracker.unregisterShuffle(shuffleId)
+                mapOutputTracker.getMapSizesByExecutorId(
+                  shuffleId,
+                  mapIndex,
+                  mapIndex + 1,
+                  startReduceId,
+                  endReduceId)
+                  .filter(_._1 != address)
+              case _ =>
+                logInfo(s"Block $blockId is not a shuffle block; cannot refresh its location")
+                Iterator.empty
+            }
+            if (newBlocksByAddr.nonEmpty) {
+              logInfo(s"New addresses found for block $blockId and mapIndex $mapIndex")
+              fallbackFetch(newBlocksByAddr)
+              result = null
+            } else {
+              throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg))
+            }
+          } else {
+            throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg))
           }
-          throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg))
 
         case DeferFetchRequestResult(request) =>
           val address = request.address
@@ -1206,6 +1248,7 @@ final class ShuffleBlockFetcherIterator(
     while (isRemoteBlockFetchable(fetchRequests)) {
       val request = fetchRequests.dequeue()
      val remoteAddress = request.address
+      // TODO: if the remote address is dead, schedule a retry against the migration destination
      if (isRemoteAddressMaxedOut(remoteAddress, request)) {
        logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
        val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 211de2e8729e..5bf4a0476ec5 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -38,6 +38,7 @@ import org.mockito.stubbing.Answer
 import org.roaringbitmap.RoaringBitmap
 
 import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext}
+import org.apache.spark.ExecutorDeadException
 import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
@@ -88,6 +89,29 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
     }
   }
 
+  /** Configures `transfer` (mock [[BlockTransferService]]) to simulate a removed executor. */
+  private def configureMockTransferDeadExecutor(data: Map[BlockId, ManagedBuffer]): Unit = {
+    var hasThrown = false
+    answerFetchBlocks { invocation =>
+      val blocks = invocation.getArgument[Array[String]](3)
+      val listener = invocation.getArgument[BlockFetchingListener](4)
+
+      for (blockId <- blocks) {
+        if (data.contains(BlockId(blockId))) {
+          // Fail the first fetch with ExecutorDeadException; succeed on later attempts.
+          if (!hasThrown) {
+            listener.onBlockFetchFailure(blockId, new ExecutorDeadException("dead :("))
+            hasThrown = true
+          } else {
+            listener.onBlockFetchSuccess(blockId, data(BlockId(blockId)))
+          }
+        } else {
+          listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId))
+        }
+      }
+    }
+  }
+
   /** Configures `transfer` (mock [[BlockTransferService]]) which mimics the Netty OOM issue. */
   private def configureNettyOOMMockTransfer(
       data: Map[BlockId, ManagedBuffer],
@@ -2076,4 +2098,40 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
     assert(iterator.next()._1 === ShuffleBlockId(0, 1, 0))
     assert(!iterator.hasNext)
   }
+
+  test("SPARK-52090: Update block location when encountering an ExecutorDeadException") {
+    val blockManager = createMockBlockManager()
+
+    val remoteBmId1 = BlockManagerId("test-remote-client-1", "test-remote-host-1", 2)
+    val remoteBmId2 = BlockManagerId("test-remote-client-2", "test-remote-host-2", 2)
+
+    val remoteBlocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()
+    )
+
+    configureMockTransferDeadExecutor(remoteBlocks)
+
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (remoteBmId1, toBlockList(remoteBlocks.keys, 1L, 0))
+    )
+
+    val migratedBlocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId2, toBlockList(remoteBlocks.keys, 1L, 0))
+    )
+    when(mapOutputTracker.getMapSizesByExecutorId(any(), any(), any(), any(), any()))
+      .thenReturn(migratedBlocksByAddress.iterator)
+
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      blocksByAddress = blocksByAddress
+    )
+
+    // Drain the iterator so the failed fetch is re-queued from the migrated location.
+    while (iterator.hasNext) {
+      iterator.next()
+    }
+
+    verify(mapOutputTracker, times(1)).unregisterShuffle(any())
+    verify(mapOutputTracker, times(1))
+      .getMapSizesByExecutorId(any(), any(), any(), any(), any())
+  }
 }
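
Note on the TODO above ("skip this if spark.storage.decommission.enabled = false"): a
minimal sketch of that gating, written as if it lived inside ShuffleBlockFetcherIterator
and assuming the flag is read from the active SparkEnv (threading the value through the
iterator's constructor would be the more testable choice). STORAGE_DECOMMISSION_ENABLED
is the existing config entry for spark.storage.decommission.enabled; the helper name
shouldAttemptLocationRefresh is hypothetical and not part of this patch.

    import org.apache.spark.{ExecutorDeadException, SparkEnv}
    import org.apache.spark.internal.config.STORAGE_DECOMMISSION_ENABLED

    // Only chase a migrated copy of the block when decommissioning is enabled;
    // without block migration, a dead executor means the map output is simply
    // gone and the existing FetchFailedException path is the right outcome.
    private def shouldAttemptLocationRefresh(e: Throwable): Boolean = {
      e.isInstanceOf[ExecutorDeadException] &&
        SparkEnv.get.conf.get(STORAGE_DECOMMISSION_ENABLED)
    }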
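The two match arms in the new fallback differ only in their reduce-id range, so they
could be collapsed into one helper. A sketch, again inside ShuffleBlockFetcherIterator
and using the same mapOutputTracker calls as the patch; the helper name
refreshedLocations is hypothetical:

    // Drop the stale output statuses so the tracker re-fetches fresh locations
    // (which will reflect any decommission-time block migration), then skip the
    // dead address.
    private def refreshedLocations(
        shuffleId: Int,
        mapIndex: Int,
        startReduceId: Int,
        endReduceId: Int,
        deadAddress: BlockManagerId): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
      mapOutputTracker.unregisterShuffle(shuffleId)
      mapOutputTracker
        .getMapSizesByExecutorId(shuffleId, mapIndex, mapIndex + 1, startReduceId, endReduceId)
        .filter(_._1 != deadAddress)
    }

    // The match in the patch then reduces to:
    val newBlocksByAddr = blockId match {
      case ShuffleBlockId(shuffleId, _, reduceId) =>
        refreshedLocations(shuffleId, mapIndex, reduceId, reduceId + 1, address)
      case ShuffleBlockBatchId(shuffleId, _, startReduceId, endReduceId) =>
        refreshedLocations(shuffleId, mapIndex, startReduceId, endReduceId, address)
      case _ =>
        logInfo(s"Block $blockId is not a shuffle block; cannot refresh its location")
        Iterator.empty
    }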