address comments.

This commit is contained in:
Yin Huai 2015-08-20 16:19:04 -07:00
parent 97093a45cd
commit 77fbe22b7b

View File

@ -36,6 +36,10 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
// Schema built by wrapping this table's field definitions in a StructType.
val schema = StructType(fields)
// Tables without partition columns use 1 partition, all others use 100; the value is
// passed to sparkContext.parallelize when the data is generated in `df`.
val partitions = if (partitionColumns.isEmpty) 1 else 100
/**
 * Builds a new Table with the same name and fields as this one, passing `Nil` as the
 * second constructor argument (presumably the partition columns -- confirm against the
 * Table constructor).
 */
def nonPartitioned: Table = Table(name, Nil, fields : _*)
def df = {
val generatedData = {
sparkContext.parallelize(1 to partitions, partitions).flatMap { i =>
@ -90,14 +94,19 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
/**
 * Writes this table's data (`df`) to `location` using the given data source `format`.
 *
 * NOTE(review): this span looks like a diff hunk whose +/- markers were lost. Lines built
 * around `withFormatAndMode` / `withPartitionColumns` and lines built around `writer`
 * are interleaved, and both are saved at the end -- confirm which set is current.
 *
 * @param location  path handed to `save`
 * @param format    data source format handed to the writer
 * @param overwrite selects SaveMode.Overwrite when true, SaveMode.Ignore otherwise
 */
def genData(location: String, format: String, overwrite: Boolean): Unit = {
val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Ignore
val withFormatAndMode = df.write.format(format).mode(mode)
val withPartitionColumns = if (partitionColumns.isEmpty) {
withFormatAndMode
// Tables with partition columns write `df` directly; tables without take the else branch.
val writer = if (!partitionColumns.isEmpty) {
df.write
} else {
withFormatAndMode.partitionBy(partitionColumns : _*)
// If the table is not partitioned, coalesce the data to a single file.
df.coalesce(1).write
}
// NOTE(review): the values returned by format/mode here (and by partitionBy below) are
// discarded, so these calls only matter if the writer is mutated in place -- confirm
// for the Spark version in use.
writer.format(format).mode(mode)
if (!partitionColumns.isEmpty) {
writer.partitionBy(partitionColumns : _*)
}
// NOTE(review): the same message is both printed and logged; presumably one of the two
// lines is the removed side of the diff.
println(s"Generating table $name in database to $location with save mode $mode.")
logInfo(s"Generating table $name in database to $location with save mode $mode.")
withPartitionColumns.save(location)
writer.save(location)
}
def createExternalTable(location: String, format: String, databaseName: String, overwrite: Boolean): Unit = {
@ -107,21 +116,27 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
sqlContext.sql(s"DROP TABLE IF EXISTS $databaseName.$name")
}
if (!tableExists || overwrite) {
println(s"Creating external table $name in database $databaseName using data stored in $location.")
logInfo(s"Creating external table $name in database $databaseName using data stored in $location.")
sqlContext.createExternalTable(qualifiedTableName, location, format)
}
}
/**
 * Loads the data stored at `location` with the given `format` and registers the result
 * as a temporary table named after this table (`name`).
 *
 * @param location path passed to `load`
 * @param format   data source format passed to the reader
 */
def createTemporaryTable(location: String, format: String): Unit = {
// NOTE(review): the same message is both printed and logged; presumably one of the two
// lines is the removed side of the diff.
println(s"Creating temporary table $name using data stored in $location.")
logInfo(s"Creating temporary table $name using data stored in $location.")
sqlContext.read.format(format).load(location).registerTempTable(name)
}
}
/**
 * Generates data for every table, calling each table's own `genData` with a per-table
 * location derived from `location` and the table name.
 *
 * NOTE(review): two signatures and two ways of computing `tableLocation` appear below
 * (a 3-argument form using File.separator/format/scaleFactor, and a 4-argument form using
 * "$location/${table.name}"); this looks like a diff hunk with its +/- markers lost --
 * confirm which form is current.
 *
 * @param location        base directory for the generated tables
 * @param format          data source format passed through to each table
 * @param overwrite       passed through to each table's `genData`
 * @param partitionTables (4-argument form) when false, every table is replaced by its
 *                        `nonPartitioned` variant before generation
 */
def genData(location: String, format: String, overwrite: Boolean): Unit = {
tables.foreach { table =>
val tableLocation =
location + File.separator + format + File.separator + "sf" + scaleFactor + File.separator + table.name
def genData(location: String, format: String, overwrite: Boolean, partitionTables: Boolean): Unit = {
// Choose between the tables as declared and their non-partitioned counterparts.
val tablesToBeGenerated = if (partitionTables) {
tables
} else {
tables.map(_.nonPartitioned)
}
tablesToBeGenerated.foreach { table =>
val tableLocation = s"$location/${table.name}"
table.genData(tableLocation, format, overwrite)
}
}
@ -129,18 +144,17 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
/**
 * Creates `databaseName` if it does not exist, asks every table to create its external
 * table from a per-table location, and then switches the current database to
 * `databaseName`.
 *
 * NOTE(review): `tableLocation` is defined twice and the final message is both printed
 * and logged; this looks like a diff hunk with its +/- markers lost -- confirm which
 * lines are current.
 *
 * @param location     base directory containing one entry per table
 * @param format       data source format passed through to each table
 * @param databaseName database that is created if missing and then selected with USE
 * @param overwrite    passed through to each table's `createExternalTable`
 */
def createExternalTables(location: String, format: String, databaseName: String, overwrite: Boolean): Unit = {
sqlContext.sql(s"CREATE DATABASE IF NOT EXISTS $databaseName")
tables.foreach { table =>
val tableLocation =
location + File.separator + format + File.separator + "sf" + scaleFactor + File.separator + table.name
val tableLocation = s"$location/${table.name}"
table.createExternalTable(tableLocation, format, databaseName, overwrite)
}
// Make the database current so later unqualified table names resolve against it.
sqlContext.sql(s"USE $databaseName")
println(s"The current database has been set to $databaseName.")
logInfo(s"The current database has been set to $databaseName.")
}
/**
 * Asks every table to register itself as a temporary table, loading from a per-table
 * location derived from `location` and the table name.
 *
 * NOTE(review): `tableLocation` is defined twice (File.separator form and
 * "$location/${table.name}" form); this looks like a diff hunk with its +/- markers
 * lost -- confirm which line is current.
 *
 * @param location base directory containing one entry per table
 * @param format   data source format passed through to each table
 */
def createTemporaryTables(location: String, format: String): Unit = {
tables.foreach { table =>
val tableLocation =
location + File.separator + format + File.separator + "sf" + scaleFactor + File.separator + table.name
val tableLocation = s"$location/${table.name}"
table.createTemporaryTable(tableLocation, format)
}
}