diff --git a/src/main/scala/com/databricks/spark/sql/perf/Tables.scala b/src/main/scala/com/databricks/spark/sql/perf/Tables.scala index 842780b..368ea5d 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/Tables.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/Tables.scala @@ -16,6 +16,9 @@ package com.databricks.spark.sql.perf +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.immutable.Stream import scala.sys.process._ import org.slf4j.LoggerFactory @@ -26,6 +29,62 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{Row, SQLContext, SaveMode} + +/** + * Using ProcessBuilder.lineStream produces a stream that uses + * a LinkedBlockingQueue with a default capacity of Integer.MAX_VALUE. + * + * This causes OOM if the consumer cannot keep up with the producer. + * + * See scala.sys.process.ProcessBuilderImpl.lineStream + */ +object BlockingLineStream { + // See scala.sys.process.Streamed + private final class BlockingStreamed[T]( + val process: T => Unit, + val done: Int => Unit, + val stream: () => Stream[T] + ) + + // See scala.sys.process.Streamed + private object BlockingStreamed { + // scala.sys.process.Streamed uses default of Integer.MAX_VALUE, + // which causes OOMs if the consumer cannot keep up with producer. 
+ val maxQueueSize = 65536 + + def apply[T](nonzeroException: Boolean): BlockingStreamed[T] = { + val q = new LinkedBlockingQueue[Either[Int, T]](maxQueueSize) + + def next(): Stream[T] = q.take match { + case Left(0) => Stream.empty + case Left(code) => + if (nonzeroException) scala.sys.error("Nonzero exit code: " + code) else Stream.empty + case Right(s) => Stream.cons(s, next()) + } + + new BlockingStreamed((s: T) => q put Right(s), code => q put Left(code), () => next()) + } + } + + // See scala.sys.process.ProcessImpl.Spawn + private object Spawn { + def apply(f: => Unit): Thread = apply(f, daemon = false) + def apply(f: => Unit, daemon: Boolean): Thread = { + val thread = new Thread() { override def run() = { f } } + thread.setDaemon(daemon) + thread.start() + thread + } + } + + def apply(command: Seq[String]): Stream[String] = { + val streamed = BlockingStreamed[String](true) + val process = command.run(BasicIO(false, streamed.process, None)) + Spawn(streamed.done(process.exitValue())) + streamed.stream() + } +} + trait DataGenerator extends Serializable { def generate( sparkContext: SparkContext, diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDSTables.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDSTables.scala index 29553c7..8243cd3 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDSTables.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDSTables.scala @@ -19,7 +19,7 @@ package com.databricks.spark.sql.perf.tpcds import scala.sys.process._ import com.databricks.spark.sql.perf -import com.databricks.spark.sql.perf.{DataGenerator, Table, Tables} +import com.databricks.spark.sql.perf.{BlockingLineStream, DataGenerator, Table, Tables} import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext @@ -44,7 +44,7 @@ class DSDGEN(dsdgenDir: String) extends DataGenerator { "bash", "-c", s"cd $localToolsDir && ./dsdgen -table $name -filter Y -scale $scaleFactor -RNGSEED 100 
$parallel") println(commands) - commands.lines + BlockingLineStream(commands) } } diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpch/TPCH.scala b/src/main/scala/com/databricks/spark/sql/perf/tpch/TPCH.scala index bbd35d0..6d78f13 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/tpch/TPCH.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/tpch/TPCH.scala @@ -17,7 +17,7 @@ package com.databricks.spark.sql.perf.tpch import scala.sys.process._ -import com.databricks.spark.sql.perf.{Benchmark, DataGenerator, Table, Tables} +import com.databricks.spark.sql.perf.{Benchmark, BlockingLineStream, DataGenerator, Table, Tables} import com.databricks.spark.sql.perf.ExecutionMode.CollectResults import org.apache.commons.io.IOUtils @@ -54,7 +54,7 @@ class DBGEN(dbgenDir: String, params: Seq[String]) extends DataGenerator { "bash", "-c", s"cd $localToolsDir && ./dbgen -q $paramsString -T ${shortTableNames(name)} -s $scaleFactor $parallel") println(commands) - commands.lines + BlockingLineStream(commands) }.repartition(numPartitions) }