[CELEBORN-2086] S3FlushTask and OssFlushTask should close ByteArrayInputStream to avoid resource leak

### What changes were proposed in this pull request?

`S3FlushTask` and `OssFlushTask` should close `ByteArrayInputStream` to avoid resource leak.

### Why are the changes needed?

`S3FlushTask` and `OssFlushTask` don't close `ByteArrayInputStream` at present, which may cause a resource leak.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI.

Closes #3395 from SteNicholas/CELEBORN-2086.

Authored-by: SteNicholas <programgeek@163.com>
Signed-off-by: Shuang <lvshuang.xjs@alibaba-inc.com>
This commit is contained in:
Authored by SteNicholas on 2025-07-29 17:19:18 +08:00 (committed by Shuang)
parent 4540b5772b
commit 392f6186df

View File

@ -17,12 +17,13 @@
package org.apache.celeborn.service.deploy.worker.storage
import java.io.ByteArrayInputStream
import java.io.{ByteArrayInputStream, Closeable, IOException}
import java.nio.channels.FileChannel
import io.netty.buffer.{ByteBufUtil, CompositeByteBuf}
import org.apache.hadoop.fs.Path
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.metrics.source.AbstractSource
import org.apache.celeborn.common.protocol.StorageInfo.Type
import org.apache.celeborn.server.common.service.mpu.MultipartUploadHandler
@ -65,20 +66,39 @@ private[worker] class LocalFlushTask(
}
}
/**
 * Base class for flush tasks that write to a distributed file system
 * (HDFS / S3 / OSS). Adds a small resource-safe flush helper on top of
 * [[FlushTask]] so subclasses never leak the stream they write through.
 */
abstract private[worker] class DfsFlushTask(
    buffer: CompositeByteBuf,
    notifier: FlushNotifier,
    keepBuffer: Boolean,
    source: AbstractSource) extends FlushTask(buffer, notifier, keepBuffer, source) with Logging {

  /**
   * Runs `block`, then unconditionally attempts to close `stream`.
   * A failure while closing is logged at WARN level and swallowed so it
   * can never mask an exception thrown by `block` itself.
   *
   * @param stream the stream to close once the flush work is done
   * @param block  the flush work to perform (evaluated lazily, by name)
   */
  def flush(stream: Closeable)(block: => Unit): Unit = {
    try block
    finally {
      // Only IOException is caught here; anything else from close() propagates.
      try stream.close()
      catch {
        case e: IOException => logWarning("Close flush dfs stream failed.", e)
      }
    }
  }
}
// NOTE(review): this span is a diff scrape with the +/- markers stripped — the
// lines below interleave the PRE-change and POST-change versions of
// HdfsFlushTask (visible from the two adjacent, conflicting `extends` lines
// and the duplicated flush body). Do not compile as-is; consult the original
// commit for the authoritative post-change source.
private[worker] class HdfsFlushTask(
buffer: CompositeByteBuf,
val path: Path,
notifier: FlushNotifier,
keepBuffer: Boolean,
// OLD base class (removed by this commit):
source: AbstractSource) extends FlushTask(buffer, notifier, keepBuffer, source) {
// NEW base class (added by this commit) — gains the stream-closing flush helper:
source: AbstractSource) extends DfsFlushTask(buffer, notifier, keepBuffer, source) {
override def flush(): Unit = {
val readableBytes = buffer.readableBytes()
val hadoopFs = StorageManager.hadoopFs.get(Type.HDFS)
// 256 KiB append buffer size — presumably chosen to match HDFS client defaults; TODO confirm.
val hdfsStream = hadoopFs.append(path, 256 * 1024)
// OLD body (removed): wrote and closed the stream manually, with no try/finally,
// so an exception from write() would leak hdfsStream.
hdfsStream.write(ByteBufUtil.getBytes(buffer))
hdfsStream.close()
source.incCounter(WorkerSource.HDFS_FLUSH_COUNT)
source.incCounter(WorkerSource.HDFS_FLUSH_SIZE, readableBytes)
// NEW body (added): same work routed through DfsFlushTask.flush, which
// closes hdfsStream in a finally block even when write() throws.
flush(hdfsStream) {
hdfsStream.write(ByteBufUtil.getBytes(buffer))
source.incCounter(WorkerSource.HDFS_FLUSH_COUNT)
source.incCounter(WorkerSource.HDFS_FLUSH_SIZE, readableBytes)
}
}
}
// NOTE(review): hunk header fused with the class opener by the scrape; the
// class parameters before this hunk (buffer/notifier/keepBuffer/source) are
// outside the visible range. The lines below interleave PRE- and POST-change
// versions of S3FlushTask (two conflicting `extends` lines, duplicated body).
@ -90,15 +110,17 @@ private[worker] class S3FlushTask(
s3MultipartUploader: MultipartUploadHandler,
partNumber: Int,
finalFlush: Boolean = false)
// OLD base class (removed by this commit):
extends FlushTask(buffer, notifier, keepBuffer, source) {
// NEW base class (added by this commit) — gains the stream-closing flush helper:
extends DfsFlushTask(buffer, notifier, keepBuffer, source) {
override def flush(): Unit = {
val readableBytes = buffer.readableBytes()
val bytes = ByteBufUtil.getBytes(buffer)
val inputStream = new ByteArrayInputStream(bytes)
// OLD body (removed): never closed inputStream — the leak this commit fixes.
s3MultipartUploader.putPart(inputStream, partNumber, finalFlush)
source.incCounter(WorkerSource.S3_FLUSH_COUNT)
source.incCounter(WorkerSource.S3_FLUSH_SIZE, readableBytes)
// NEW body (added): putPart + metrics wrapped in DfsFlushTask.flush so
// inputStream is closed in a finally block.
flush(inputStream) {
s3MultipartUploader.putPart(inputStream, partNumber, finalFlush)
source.incCounter(WorkerSource.S3_FLUSH_COUNT)
source.incCounter(WorkerSource.S3_FLUSH_SIZE, readableBytes)
}
}
}
// NOTE(review): same scrape artifact as the hunks above — hunk header fused
// with the class opener, earlier class parameters out of view, and PRE- and
// POST-change lines of OssFlushTask interleaved without +/- markers.
@ -110,14 +132,16 @@ private[worker] class OssFlushTask(
ossMultipartUploader: MultipartUploadHandler,
partNumber: Int,
finalFlush: Boolean = false)
// OLD base class (removed by this commit):
extends FlushTask(buffer, notifier, keepBuffer, source) {
// NEW base class (added by this commit) — gains the stream-closing flush helper:
extends DfsFlushTask(buffer, notifier, keepBuffer, source) {
override def flush(): Unit = {
val readableBytes = buffer.readableBytes()
val bytes = ByteBufUtil.getBytes(buffer)
val inputStream = new ByteArrayInputStream(bytes)
// OLD body (removed): never closed inputStream — the leak this commit fixes.
ossMultipartUploader.putPart(inputStream, partNumber, finalFlush)
source.incCounter(WorkerSource.OSS_FLUSH_COUNT)
source.incCounter(WorkerSource.OSS_FLUSH_SIZE, readableBytes)
// NEW body (added): putPart + metrics wrapped in DfsFlushTask.flush so
// inputStream is closed in a finally block.
flush(inputStream) {
ossMultipartUploader.putPart(inputStream, partNumber, finalFlush)
source.incCounter(WorkerSource.OSS_FLUSH_COUNT)
source.incCounter(WorkerSource.OSS_FLUSH_SIZE, readableBytes)
}
}
}