diff --git a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala index ff4d645..2808ece 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala @@ -123,6 +123,10 @@ abstract class Benchmark( @volatile var failures = 0 @volatile var startTime = 0L + /** An optional log collection task that will run after the experiment. */ + @volatile var logCollection: () => Unit = () => {} + + def cartesianProduct[T](xss: List[List[T]]): List[List[T]] = xss match { case Nil => List(Nil) case h :: t => for(xh <- h; xt <- cartesianProduct(t)) yield xh :: xt @@ -203,17 +207,21 @@ abstract class Benchmark( } catch { case e: Throwable => currentMessages += s"Failed to write data: $e" } + + logCollection() } - def scheduleCpuCollection(fs: FS) = resultsFuture.onComplete { _ => - currentMessages += s"Begining CPU log collection" - try { - val location = cpu.collectLogs(sqlContext, fs, timestamp) - currentMessages += s"cpu results recorded to $location" - } catch { - case e: Throwable => - currentMessages += s"Error collecting logs: $e" - throw e + def scheduleCpuCollection(fs: FS) = { + logCollection = () => { + currentMessages += s"Begining CPU log collection" + try { + val location = cpu.collectLogs(sqlContext, fs, timestamp) + currentMessages += s"cpu results recorded to $location" + } catch { + case e: Throwable => + currentMessages += s"Error collecting logs: $e" + throw e + } } }