diff --git a/README.md b/README.md index a2f41e2..b6e0aee 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ import com.databricks.spark.sql.perf.tpcds.Tables // dsdgenDir is the location of dsdgen tool installed in your machines. val tables = new Tables(sqlContext, dsdgenDir, scaleFactor) // Generate data. -tables.genData(location, format, overwrite, partitionTables, useDoubleForDecimal) +tables.genData(location, format, overwrite, partitionTables, useDoubleForDecimal, orderByPartitionColumns) // Create metastore tables in a specified database for your data. // Once tables are created, the current database will be switched to the specified database. tables.createExternalTables(location, format, databaseName, overwrite) diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala index 51b687b..702e8b1 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala @@ -103,17 +103,21 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend Table(name, partitionColumns, newFields:_*) } - def genData(location: String, format: String, overwrite: Boolean): Unit = { + def genData(location: String, format: String, overwrite: Boolean, orderByPartitionColumns: Boolean): Unit = { val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Ignore - val writer = if (!partitionColumns.isEmpty) { - df.write + val writer = if (partitionColumns.nonEmpty) { + if (orderByPartitionColumns) { + df.orderBy(partitionColumns.map(col): _*).write + } else { + df.write + } } else { // If the table is not partitioned, coalesce the data to a single file. df.coalesce(1).write } writer.format(format).mode(mode) - if (!partitionColumns.isEmpty) { + if (partitionColumns.nonEmpty) { writer.partitionBy(partitionColumns : _*) } println(s"Generating table $name in database to $location with save mode $mode.") @@ -146,7 +150,8 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend format: String, overwrite: Boolean, partitionTables: Boolean, - useDoubleForDecimal: Boolean): Unit = { + useDoubleForDecimal: Boolean, + orderByPartitionColumns: Boolean): Unit = { val tablesToBeGenerated = if (partitionTables) { tables } else { @@ -161,7 +166,7 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend withSpecifiedDataType.foreach { table => val tableLocation = s"$location/${table.name}" - table.genData(tableLocation, format, overwrite) + table.genData(tableLocation, format, overwrite, orderByPartitionColumns) } }