diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
index 3e61a9d97..46c58d0da 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
@@ -24,7 +24,6 @@ import java.util.concurrent._
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray}
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable
 
 import com.google.common.annotations.VisibleForTesting
 import io.netty.util.HashedWheelTimer
@@ -47,7 +46,7 @@ import org.apache.celeborn.common.quota.ResourceConsumption
 import org.apache.celeborn.common.rpc._
 import org.apache.celeborn.common.util.{ShutdownHookManager, ThreadUtils, Utils}
 import org.apache.celeborn.server.common.{HttpService, Service}
-import org.apache.celeborn.service.deploy.worker.storage.{PartitionFilesSorter, StorageManager}
+import org.apache.celeborn.service.deploy.worker.storage.{FileWriter, PartitionFilesSorter, StorageManager}
 
 private[celeborn] class Worker(
     override val conf: CelebornConf,
@@ -409,19 +408,27 @@ private[celeborn] class Worker(
     // If worker register still failed after retry, throw exception to stop worker process
     throw new CelebornException("Register worker failed.", exception)
   }
 
-  private def cleanup(expiredShuffleKeys: JHashSet[String]): Unit = synchronized {
+  /**
+   * Destroys the file writers of every master and slave partition that belongs to an expired
+   * shuffle, detaches them from the per-working-dir writer lists and drops the shuffle's state.
+   */
+  @VisibleForTesting
+  def cleanup(expiredShuffleKeys: JHashSet[String]): Unit = synchronized {
     expiredShuffleKeys.asScala.foreach { shuffleKey =>
       partitionLocationInfo.getAllMasterLocations(shuffleKey).asScala.foreach { partition =>
         val fileWriter = partition.asInstanceOf[WorkingPartition].getFileWriter
         fileWriter.destroy(new IOException(
           s"Destroy FileWriter ${fileWriter} caused by shuffle ${shuffleKey} expired."))
+        removeExpiredWorkingDirWriters(fileWriter)
       }
       partitionLocationInfo.getAllSlaveLocations(shuffleKey).asScala.foreach { partition =>
         val fileWriter = partition.asInstanceOf[WorkingPartition].getFileWriter
         fileWriter.destroy(new IOException(
           s"Destroy FileWriter ${fileWriter} caused by shuffle ${shuffleKey} expired."))
+        removeExpiredWorkingDirWriters(fileWriter)
       }
+
       partitionLocationInfo.removeMasterPartitions(shuffleKey)
       partitionLocationInfo.removeSlavePartitions(shuffleKey)
       shufflePartitionType.remove(shuffleKey)
@@ -434,6 +441,23 @@ private[celeborn] class Worker(
     storageManager.cleanupExpiredShuffleKey(expiredShuffleKeys)
   }
 
+  /**
+   * Removes `fileWriter` from the writer list tracked for its working directory, if any.
+   *
+   * @param fileWriter writer whose file lives at dir/appId/shuffleId/filename
+   */
+  @VisibleForTesting
+  def removeExpiredWorkingDirWriters(fileWriter: FileWriter): Unit = {
+    // filepath is dir/appId/shuffleId/filename, so the working dir is three levels up
+    val dir = fileWriter.getFile.getParentFile.getParentFile.getParentFile
+    // foreach, not map: the removal is a pure side effect and there is no result to keep
+    storageManager.workingDirWriters.asScala.get(dir).foreach { writers =>
+      writers.synchronized {
+        writers.remove(fileWriter)
+      }
+    }
+  }
+
   override def getWorkerInfo: String = workerInfo.toString()
 
   override def getThreadDump: String = Utils.getThreadDump()
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/WorkerSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/WorkerSuite.scala
new file mode 100644
index 000000000..0939f168b
--- /dev/null
+++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/WorkerSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.storage
+
+import java.io.File
+import java.nio.file.Files
+import java.util
+import java.util.{HashSet => JHashSet}
+
+import org.junit.Assert
+import org.mockito.MockitoSugar._
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.meta.FileInfo
+import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
+import org.apache.celeborn.service.deploy.worker.{Worker, WorkerArguments, WorkerSource, WorkingPartition}
+
+class WorkerSuite extends AnyFunSuite {
+  val conf = new CelebornConf()
+  val workerArgs = new WorkerArguments(Array(), conf)
+
+  test("clean up") {
+    val worker = new Worker(conf, workerArgs)
+    val expiredShuffleKeys = new JHashSet[String]()
+    val shuffleKey1 = "s1"
+    val shuffleKey2 = "s2"
+    expiredShuffleKeys.add(shuffleKey1)
+    // No partitions are registered for shuffleKey2 in this test.
+    expiredShuffleKeys.add(shuffleKey2)
+    val pl1 = new PartitionLocation(0, 0, "12", 0, 0, 0, 0, PartitionLocation.Mode.MASTER)
+    val pl2 = new PartitionLocation(0, 0, "12", 0, 0, 0, 0, PartitionLocation.Mode.SLAVE)
+
+    // Fresh scratch root per run (no fixed /tmp paths); removed even if an assertion fails.
+    val root = Files.createTempDirectory("celeborn-worker-suite").toFile.getAbsoluteFile
+    try {
+      val dir1 = new File(root, "work1")
+      val dir2 = new File(root, "work2")
+      val dir3 = new File(root, "work3")
+
+      val fw1 = createWriter(createDataFile(dir1, "1"))
+      val fw2 = createWriter(createDataFile(dir2, "2"))
+      val fw3 = createWriter(createDataFile(dir3, "3"))
+
+      val wl1 = new WorkingPartition(pl1, fw1)
+      val wl2 = new WorkingPartition(pl2, fw2)
+      val wl3 = new WorkingPartition(pl2, fw3)
+      worker.partitionLocationInfo.addMasterPartitions(shuffleKey1, util.Arrays.asList(wl1))
+      worker.partitionLocationInfo.addMasterPartitions(shuffleKey1, util.Arrays.asList(wl3))
+      worker.partitionLocationInfo.addSlavePartitions(shuffleKey1, util.Arrays.asList(wl2))
+
+      // dir3 gets no entry in workingDirWriters, unlike dir1 and dir2.
+      val fws1 = new util.ArrayList[FileWriter]()
+      fws1.add(fw1)
+      val fws2 = new util.ArrayList[FileWriter]()
+      fws2.add(fw2)
+      worker.storageManager.workingDirWriters.put(dir1, fws1)
+      worker.storageManager.workingDirWriters.put(dir2, fws2)
+      Assert.assertEquals(1, worker.storageManager.workingDirWriters.get(dir1).size())
+      Assert.assertEquals(1, worker.storageManager.workingDirWriters.get(dir2).size())
+      worker.cleanup(expiredShuffleKeys)
+      Assert.assertEquals(0, worker.storageManager.workingDirWriters.get(dir1).size())
+      Assert.assertEquals(0, worker.storageManager.workingDirWriters.get(dir2).size())
+    } finally {
+      deleteFile(root)
+    }
+  }
+
+  /** Creates an empty file at `workDir/name/name/name`, i.e. dir/appId/shuffleId/filename. */
+  private def createDataFile(workDir: File, name: String): File = {
+    val shuffleDir = new File(new File(workDir, name), name)
+    shuffleDir.mkdirs()
+    val dataFile = new File(shuffleDir, name)
+    dataFile.createNewFile()
+    dataFile
+  }
+
+  /** Builds a ReducePartitionFileWriter for `file` with a mocked Flusher and DeviceMonitor. */
+  private def createWriter(file: File): FileWriter = {
+    val writer = new ReducePartitionFileWriter(
+      new FileInfo(file.getAbsolutePath, null, null, PartitionType.REDUCE),
+      mock[Flusher],
+      new WorkerSource(conf),
+      conf,
+      mock[DeviceMonitor],
+      0,
+      null,
+      false)
+    writer.registerDestroyHook(new util.ArrayList(util.Arrays.asList(writer)))
+    writer
+  }
+
+  /** Recursively deletes `dir` and everything below it; a plain file is deleted directly. */
+  def deleteFile(dir: File): Unit = {
+    // listFiles() returns null when `dir` is not a directory
+    Option(dir.listFiles()).foreach(_.foreach(deleteFile))
+    dir.delete()
+  }
+}