diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index 40cfcb136..b6b729dd0 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -82,15 +82,22 @@ class PushDataHandler extends BaseMessageHandler with Logging {
           client,
           pushData,
           pushData.requestId,
-          () =>
-            handlePushData(
-              pushData,
-              new SimpleRpcResponseCallback(
-                Type.PUSH_DATA,
-                client,
-                pushData.requestId,
-                pushData.shuffleKey,
-                pushData.partitionUniqueId)))
+          () => {
+            val callback = new SimpleRpcResponseCallback(
+              Type.PUSH_DATA,
+              client,
+              pushData.requestId,
+              pushData.shuffleKey,
+              pushData.partitionUniqueId)
+            shufflePartitionType.getOrDefault(pushData.shuffleKey, PartitionType.REDUCE) match {
+              case PartitionType.REDUCE => handlePushData(
+                  pushData,
+                  callback)
+              case PartitionType.MAP => handleMapPartitionPushData(
+                  pushData,
+                  callback)
+            }
+          })
       case pushMergedData: PushMergedData => {
         handleCore(
           client,
@@ -532,6 +539,103 @@ class PushDataHandler extends BaseMessageHandler with Logging {
     }
   }
 
+  /**
+   * Handles a [[PushData]] request for a shuffle whose partition type is MAP.
+   *
+   * Looks up the master or slave location of the pushed partition, answers HARD_SPLIT while
+   * the worker is shutting down, acknowledges the request with an empty payload and then
+   * writes the body through the file writer of that location. No disk-full or split check
+   * is done for MAP partitions, and replication to the peer is still a to-do.
+   *
+   * @param pushData the push request; its mode selects the master or slave location
+   * @param callback callback used to answer the client
+   */
+  def handleMapPartitionPushData(pushData: PushData, callback: RpcResponseCallback): Unit = {
+    val shuffleKey = pushData.shuffleKey
+    val mode = PartitionLocation.getMode(pushData.mode)
+    val body = pushData.body.asInstanceOf[NettyManagedBuffer].getBuf
+    val isMaster = mode == PartitionLocation.Mode.MASTER
+
+    val key = s"${pushData.requestId}"
+    if (isMaster) {
+      workerSource.startTimer(WorkerSource.MasterPushDataTime, key)
+    } else {
+      workerSource.startTimer(WorkerSource.SlavePushDataTime, key)
+    }
+
+    // find FileWriter responsible for the data
+    val location =
+      if (isMaster) {
+        partitionLocationInfo.getMasterLocation(shuffleKey, pushData.partitionUniqueId)
+      } else {
+        partitionLocationInfo.getSlaveLocation(shuffleKey, pushData.partitionUniqueId)
+      }
+
+    val wrappedCallback =
+      new WrappedRpcResponseCallback(
+        pushData.`type`(),
+        isMaster,
+        pushData.requestId,
+        null,
+        location,
+        if (isMaster) WorkerSource.MasterPushDataTime else WorkerSource.SlavePushDataTime,
+        callback)
+
+    if (checkLocationNull(
+        pushData.`type`(),
+        shuffleKey,
+        pushData.partitionUniqueId,
+        null,
+        location,
+        callback,
+        wrappedCallback)) return
+
+    // During worker shutdown, worker will return HARD_SPLIT for all existed partition.
+    // This should before return exception to make current push data can revive and retry.
+    if (shutdown.get()) {
+      logInfo(s"Push data return HARD_SPLIT for shuffle $shuffleKey since worker shutdown.")
+      callback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
+      return
+    }
+
+    val fileWriter =
+      getFileWriterAndCheck(pushData.`type`(), location, isMaster, callback) match {
+        case (true, _) => return
+        case (false, f: FileWriter) => f
+      }
+
+    // for mappartition we will not check whether disk full or split partition
+
+    fileWriter.incrementPendingWrites()
+
+    // for master, send data to slave
+    if (location.getPeer != null && isMaster) {
+      // to do
+      wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+    } else {
+      wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte]()))
+    }
+
+    try {
+      fileWriter.write(body)
+    } catch {
+      case e: AlreadyClosedException =>
+        fileWriter.decrementPendingWrites()
+        // MAP partitions pack map id and attempt id into the partition id, so decode them from
+        // partitionUniqueId rather than from the body.
+        val (mapId, attemptId) = getMapAttempt(pushData.partitionUniqueId)
+        val endedAttempt =
+          if (shuffleMapperAttempts.containsKey(shuffleKey)) {
+            shuffleMapperAttempts.get(shuffleKey).get(mapId)
+          } else -1
+        // TODO just info log for ended attempt
+        logWarning(s"Append data failed for task(shuffle $shuffleKey, map $mapId, attempt" +
+          s" $attemptId), caused by ${e.getMessage}")
+      case e: Exception =>
+        logError("Exception encountered when write.", e)
+    }
+  }
+
   private def handleRpcRequest(client: TransportClient, rpcRequest: RpcRequest): Unit = {
     val msg = Message.decode(rpcRequest.body().nioByteBuffer())
     val requestId = rpcRequest.requestId
@@ -701,7 +805,7 @@ class PushDataHandler extends BaseMessageHandler with Logging {
       callback: RpcResponseCallback,
       wrappedCallback: RpcResponseCallback): Boolean = {
     if (location == null) {
-      val (mapId, attemptId) = getMapAttempt(body, shuffleKey, partitionUniqueId)
+      val (mapId, attemptId) = getMapAttempt(partitionUniqueId)
       if (shuffleMapperAttempts.containsKey(shuffleKey)
         && -1 != shuffleMapperAttempts.get(shuffleKey).get(mapId)) {
         // partition data has already been committed
@@ -792,15 +896,8 @@ class PushDataHandler extends BaseMessageHandler with Logging {
   }
 
   private def getMapAttempt(
-      body: ByteBuf,
-      shuffleKey: String,
       partitionUniqueId: String): (Int, Int) = {
-    shufflePartitionType.get(shuffleKey) match {
-      case PartitionType.MAP => {
-        val id = partitionUniqueId.split("-")(0).toInt
-        (PackedPartitionId.getRawPartitionId(id), PackedPartitionId.getAttemptId(id))
-      }
-      case PartitionType.REDUCE => getMapAttempt(body)
-    }
+    val id = partitionUniqueId.split("-")(0).toInt
+    (PackedPartitionId.getRawPartitionId(id), PackedPartitionId.getAttemptId(id))
   }
 }