From c087b68a5ca0667395095acb18f337f38a93d5ef Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 24 May 2016 10:40:51 -0700 Subject: [PATCH] make number of partitions configurable --- .../databricks/spark/sql/perf/tpcds/Tables.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala index e6c75ec..414e04d 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala @@ -34,7 +34,6 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend case class Table(name: String, partitionColumns: Seq[String], fields: StructField*) { val schema = StructType(fields) - val partitions = if (partitionColumns.isEmpty) 1 else 100 def nonPartitioned: Table = { Table(name, Nil, fields : _*) @@ -44,7 +43,8 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend * If convertToSchema is true, the data from generator will be parsed into columns and * converted to `schema`. Otherwise, it just outputs the raw data (as a single STRING column). */ - def df(convertToSchema: Boolean) = { + def df(convertToSchema: Boolean, numPartition: Int) = { + val partitions = if (partitionColumns.isEmpty) 1 else numPartition val generatedData = { sparkContext.parallelize(1 to partitions, partitions).flatMap { i => val localToolsDir = if (new java.io.File(dsdgen).exists) { @@ -121,10 +121,11 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend format: String, overwrite: Boolean, clusterByPartitionColumns: Boolean, - filterOutNullPartitionValues: Boolean): Unit = { + filterOutNullPartitionValues: Boolean, + numPartitions: Int): Unit = { val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Ignore - val data = df(format != "text") + val data = df(format != "text", numPartitions) val tempTableName = s"${name}_text" data.registerTempTable(tempTableName) @@ -199,7 +200,8 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend useDoubleForDecimal: Boolean, clusterByPartitionColumns: Boolean, filterOutNullPartitionValues: Boolean, - tableFilter: String = ""): Unit = { + tableFilter: String = "", + numPartitions: Int = 100): Unit = { var tablesToBeGenerated = if (partitionTables) { tables } else { @@ -222,7 +224,7 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend withSpecifiedDataType.foreach { table => val tableLocation = s"$location/${table.name}" table.genData(tableLocation, format, overwrite, clusterByPartitionColumns, - filterOutNullPartitionValues) + filterOutNullPartitionValues, numPartitions) } }