diff --git a/README.md b/README.md index 198cd63..520049c 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ import com.databricks.spark.sql.perf.tpcds.Tables // dsdgenDir is the location of dsdgen tool installed in your machines. val tables = new Tables(sqlContext, dsdgenDir, scaleFactor) // Generate data. -tables.genData(location, format, overwrite, partitionTables, useDoubleForDecimal, clusterByPartitionColumns) +tables.genData(location, format, overwrite, partitionTables, useDoubleForDecimal, clusterByPartitionColumns, filterOutNullPartitionValues) // Create metastore tables in a specified database for your data. // Once tables are created, the current database will be switched to the specified database. tables.createExternalTables(location, format, databaseName, overwrite) diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala index 1749440..ee470f8 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala @@ -104,7 +104,12 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend Table(name, partitionColumns, newFields:_*) } - def genData(location: String, format: String, overwrite: Boolean, clusterByPartitionColumns: Boolean): Unit = { + def genData( + location: String, + format: String, + overwrite: Boolean, + clusterByPartitionColumns: Boolean, + filterOutNullPartitionValues: Boolean): Unit = { val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Ignore val data = df @@ -117,6 +122,11 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend field.name }.mkString(",") val partitionColumnString = partitionColumns.mkString(",") + val predicates = if (filterOutNullPartitionValues) { + partitionColumns.map(col => s"$col IS NOT NULL").mkString("WHERE ", " AND ", "") + } else { + "" + } val query = s""" @@ -124,6 +134,7 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend | $columnString |FROM | $tempTableName + |$predicates |DISTRIBUTE BY | $partitionColumnString """.stripMargin @@ -174,7 +185,8 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend overwrite: Boolean, partitionTables: Boolean, useDoubleForDecimal: Boolean, - clusterByPartitionColumns: Boolean): Unit = { + clusterByPartitionColumns: Boolean, + filterOutNullPartitionValues: Boolean): Unit = { val tablesToBeGenerated = if (partitionTables) { tables } else { @@ -189,7 +201,7 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend withSpecifiedDataType.foreach { table => val tableLocation = s"$location/${table.name}" - table.genData(tableLocation, format, overwrite, clusterByPartitionColumns) + table.genData(tableLocation, format, overwrite, clusterByPartitionColumns, filterOutNullPartitionValues) } }