diff --git a/src/main/scala/com/databricks/spark/sql/perf/Tables.scala b/src/main/scala/com/databricks/spark/sql/perf/Tables.scala
index 775dfd8..4d3657b 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/Tables.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/Tables.scala
@@ -179,7 +179,7 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
 
       val data = df(format != "text", numPartitions)
       val tempTableName = s"${name}_text"
-      data.registerTempTable(tempTableName)
+      data.createOrReplaceTempView(tempTableName)
 
       val writer = if (partitionColumns.nonEmpty) {
         if (clusterByPartitionColumns) {
@@ -211,9 +211,25 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
           data.write
         }
       } else {
+        // treat non-partitioned tables as "one partition" that we want to coalesce
         if (clusterByPartitionColumns) {
-          // treat non-partitioned tables as "one partition" that we want to coalesce
-          data.coalesce(1).write
+          // in case data has more than maxRecordsPerFile, split into multiple writers to improve datagen speed
+          // files will be truncated to maxRecordsPerFile value, so the final result will be the same
+          val numRows = data.count
+          val maxRecordPerFile = util.Try(sqlContext.getConf("spark.sql.files.maxRecordsPerFile").toInt).getOrElse(0)
+
+          println(s"Data has $numRows rows clustered $clusterByPartitionColumns for $maxRecordPerFile")
+          log.info(s"Data has $numRows rows clustered $clusterByPartitionColumns for $maxRecordPerFile")
+
+          if (maxRecordPerFile > 0 && numRows > maxRecordPerFile) {
+            // divide as Double: Long / Int truncates before .ceil, which would under-count the writers
+            val numFiles = (numRows.toDouble / maxRecordPerFile).ceil.toInt
+            println(s"Coalescing into $numFiles files")
+            log.info(s"Coalescing into $numFiles files")
+            data.coalesce(numFiles).write
+          } else {
+            data.coalesce(1).write
+          }
         } else {
           data.write
         }
@@ -251,7 +267,7 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
     def createTemporaryTable(location: String, format: String): Unit = {
       println(s"Creating temporary table $name using data stored in $location.")
       log.info(s"Creating temporary table $name using data stored in $location.")
-      sqlContext.read.format(format).load(location).registerTempTable(name)
+      sqlContext.read.format(format).load(location).createOrReplaceTempView(name)
     }
 
     def analyzeTable(databaseName: String, analyzeColumns: Boolean = false): Unit = {