diff --git a/build.sbt b/build.sbt index d2b5770d..32e75013 100644 --- a/build.sbt +++ b/build.sbt @@ -94,6 +94,7 @@ lazy val migrator = (project in file("migrator")).settings( lazy val tests = project.in(file("tests")).settings( libraryDependencies ++= Seq( "com.amazonaws" % "aws-java-sdk-dynamodb" % awsSdkVersion, + "org.apache.spark" %% "spark-sql" % sparkVersion, "org.apache.cassandra" % "java-driver-query-builder" % "4.18.0", "com.github.mjakubowski84" %% "parquet4s-core" % "1.9.4", "org.apache.hadoop" % "hadoop-client" % "2.9.2", diff --git a/migrator/src/main/scala/com/scylladb/migrator/DynamoUtils.scala b/migrator/src/main/scala/com/scylladb/migrator/DynamoUtils.scala index aaee5158..04f798ac 100644 --- a/migrator/src/main/scala/com/scylladb/migrator/DynamoUtils.scala +++ b/migrator/src/main/scala/com/scylladb/migrator/DynamoUtils.scala @@ -186,7 +186,7 @@ object DynamoUtils { jobConf.set( "mapred.output.format.class", "org.apache.hadoop.dynamodb.write.DynamoDBOutputFormat") - jobConf.set("mapred.input.format.class", "org.apache.hadoop.dynamodb.read.DynamoDBInputFormat") + jobConf.set("mapred.input.format.class", "com.scylladb.migrator.alternator.DynamoDBInputFormat") } /** diff --git a/migrator/src/main/scala/com/scylladb/migrator/alternator/DynamoDBInputFormat.scala b/migrator/src/main/scala/com/scylladb/migrator/alternator/DynamoDBInputFormat.scala new file mode 100644 index 00000000..8d20c323 --- /dev/null +++ b/migrator/src/main/scala/com/scylladb/migrator/alternator/DynamoDBInputFormat.scala @@ -0,0 +1,86 @@ +package com.scylladb.migrator.alternator +import org.apache.commons.logging.LogFactory +import org.apache.hadoop.dynamodb.DynamoDBConstants +import org.apache.hadoop.mapred.{ InputSplit, JobConf } + +/** + * Specializes the split strategy: + * - do not bound the maximum number of partitions by the available memory per node + * - use as many partitions as the number of scan segments + * - by default, create segments that split the data into 128 MB chunks + */ +class DynamoDBInputFormat extends org.apache.hadoop.dynamodb.read.DynamoDBInputFormat { + + private val log = LogFactory.getLog(classOf[DynamoDBInputFormat]) + + override def getSplits(conf: JobConf, desiredSplits: Int): Array[InputSplit] = { + val readPercentage = conf.getDouble( + DynamoDBConstants.THROUGHPUT_READ_PERCENT, + DynamoDBConstants.DEFAULT_THROUGHPUT_PERCENTAGE.toDouble) + if (readPercentage <= 0) { + sys.error(s"Invalid read percentage: ${readPercentage}") + } + log.info(s"Read percentage: ${readPercentage}") + val maxReadThroughputAllocated = conf.getInt(DynamoDBConstants.READ_THROUGHPUT, 1) + val maxWriteThroughputAllocated = conf.getInt(DynamoDBConstants.WRITE_THROUGHPUT, 1) + if (maxReadThroughputAllocated < 1.0) { + sys.error( + s"Read throughput should not be less than 1. Read throughput percent: ${maxReadThroughputAllocated}") + } + + val configuredReadThroughput = + math.max(math.floor(maxReadThroughputAllocated * readPercentage).intValue(), 1) + + val tableSizeBytes = conf.getLong(DynamoDBConstants.TABLE_SIZE_BYTES, 1) + val numSegments = + getNumSegments(maxReadThroughputAllocated, maxWriteThroughputAllocated, tableSizeBytes, conf) + + val numMappers = getNumMappers(numSegments, configuredReadThroughput, conf) + + log.info(s"Using ${numSegments} segments across ${numMappers} mappers") + + getSplitGenerator().generateSplits(numMappers, numSegments, conf) + } + + override def getNumSegments(tableNormalizedReadThroughput: Int, + tableNormalizedWriteThroughput: Int, + currentTableSizeBytes: Long, + conf: JobConf): Int = { + // Use configured scan segment if provided + val configuredScanSegment = conf.getInt(DynamoDBConstants.SCAN_SEGMENTS, -1) + if (configuredScanSegment > 0) { + val numSegments = + math.max( + math.min(configuredScanSegment, DynamoDBConstants.MAX_SCAN_SEGMENTS), + DynamoDBConstants.MIN_SCAN_SEGMENTS + ) + log.info( + s"Using number of segments configured using ${DynamoDBConstants.SCAN_SEGMENTS}: ${numSegments}") + numSegments + } else { + // split into segments of at most 100 MB each (note: upstream implementation splits into 1 GB segments) + val numSegmentsForSize = { + val bytesPerSegment = 100 * 1024 * 1024 + (currentTableSizeBytes / bytesPerSegment).ceil.intValue() + } + log.info(s"Would use ${numSegmentsForSize} segments for size") + + val numSegmentsForThroughput = + (tableNormalizedReadThroughput / DynamoDBConstants.MIN_IO_PER_SEGMENT).intValue() + log.info(s"Would use ${numSegmentsForThroughput} segments for throughput") + + // Take the smallest and fit to bounds + val numSegments = + math.max( + math.min( + math.min(numSegmentsForSize, numSegmentsForThroughput), + DynamoDBConstants.MAX_SCAN_SEGMENTS + ), + DynamoDBConstants.MIN_SCAN_SEGMENTS + ) + log.info(s"Using computed number of segments: ${numSegments}") + numSegments + } + } + +} diff --git a/migrator/src/main/scala/com/scylladb/migrator/readers/DynamoDB.scala b/migrator/src/main/scala/com/scylladb/migrator/readers/DynamoDB.scala index 7dc33b5d..ab583c1b 100644 --- a/migrator/src/main/scala/com/scylladb/migrator/readers/DynamoDB.scala +++ b/migrator/src/main/scala/com/scylladb/migrator/readers/DynamoDB.scala @@ -3,9 +3,9 @@ package com.scylladb.migrator.readers import com.amazonaws.services.dynamodbv2.model.TableDescription import com.scylladb.migrator.DynamoUtils import com.scylladb.migrator.DynamoUtils.{ setDynamoDBJobConf, setOptionalConf } +import com.scylladb.migrator.alternator.DynamoDBInputFormat import com.scylladb.migrator.config.{ AWSCredentials, DynamoDBEndpoint, SourceSettings } import org.apache.hadoop.dynamodb.{ DynamoDBConstants, DynamoDBItemWritable } -import org.apache.hadoop.dynamodb.read.DynamoDBInputFormat import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.JobConf import org.apache.spark.rdd.RDD @@ -41,10 +41,46 @@ object DynamoDB { maxMapTasks: Option[Int], readThroughput: Option[Int], throughputReadPercent: Option[Float]): (RDD[(Text, DynamoDBItemWritable)], TableDescription) = { - val description = DynamoUtils + + val tableDescription = DynamoUtils .buildDynamoClient(endpoint, credentials.map(_.toAWSCredentialsProvider), region) .describeTable(table) .getTable + + val jobConf = + makeJobConf( + spark, + endpoint, + credentials, + region, + table, + scanSegments, + maxMapTasks, + readThroughput, + throughputReadPercent, + tableDescription) + + val rdd = + spark.sparkContext.hadoopRDD( + jobConf, + classOf[DynamoDBInputFormat], + classOf[Text], + classOf[DynamoDBItemWritable]) + (rdd, tableDescription) + } + + private[migrator] def makeJobConf( + spark: SparkSession, + endpoint: Option[DynamoDBEndpoint], + credentials: Option[AWSCredentials], + region: Option[String], + table: String, + scanSegments: Option[Int], + maxMapTasks: Option[Int], + readThroughput: Option[Int], + throughputReadPercent: Option[Float], + description: TableDescription + ): JobConf = { val maybeItemCount = Option(description.getItemCount).map(_.toLong) val maybeAvgItemSize = for { @@ -66,6 +102,7 @@ object DynamoDB { jobConf.set(DynamoDBConstants.INPUT_TABLE_NAME, table) setOptionalConf(jobConf, DynamoDBConstants.ITEM_COUNT, maybeItemCount.map(_.toString)) setOptionalConf(jobConf, DynamoDBConstants.AVG_ITEM_SIZE, maybeAvgItemSize.map(_.toString)) + setOptionalConf(jobConf, DynamoDBConstants.TABLE_SIZE_BYTES, Option(description.getTableSizeBytes).map(_.toString)) jobConf.set( DynamoDBConstants.READ_THROUGHPUT, readThroughput @@ -76,13 +113,7 @@ object DynamoDB { DynamoDBConstants.THROUGHPUT_READ_PERCENT, throughputReadPercent.map(_.toString)) - val rdd = - spark.sparkContext.hadoopRDD( - jobConf, - classOf[DynamoDBInputFormat], - classOf[Text], - classOf[DynamoDBItemWritable]) - (rdd, description) + jobConf } } diff --git a/tests/src/test/scala/com/scylladb/migrator/alternator/DynamoDBInputFormatTest.scala b/tests/src/test/scala/com/scylladb/migrator/alternator/DynamoDBInputFormatTest.scala new file mode 100644 index 00000000..4869cc5d --- /dev/null +++ b/tests/src/test/scala/com/scylladb/migrator/alternator/DynamoDBInputFormatTest.scala @@ -0,0 +1,79 @@ +package com.scylladb.migrator.alternator + +import com.amazonaws.services.dynamodbv2.model.{BillingMode, BillingModeSummary, ProvisionedThroughputDescription, TableDescription} +import com.scylladb.migrator.readers.DynamoDB +import org.apache.spark.sql.SparkSession + +class DynamoDBInputFormatTest extends munit.FunSuite { + + val spark = SparkSession.builder().appName("test").master("local[*]").getOrCreate() + + val GB: Long = 1024 * 1024 * 1024 + + test("no configured scanSegments in on-demand billing mode") { + checkPartitions(10)(tableSizeBytes = 1 * GB, tableProvisionedThroughput = None) + } + + test("no configured scanSegments in on-demand billing mode and table size is 100 GB") { + checkPartitions(1024)(tableSizeBytes = 100 * GB, tableProvisionedThroughput = None) + } + + test("no configured scanSegments in provisioned billing mode") { + checkPartitions(10)(tableSizeBytes = 1 * GB, tableProvisionedThroughput = Some((25, 25))) + } + + test("scanSegments = 42") { + checkPartitions(42)(configuredScanSegments = Some(42)) + } + + test("scanSegements = 42 and maxMapTasks = 10") { + checkPartitions(10)(configuredScanSegments = Some(42), configuredMaxMapTasks = Some(10)) + } + + def checkPartitions(expectedPartitions: Int)( + tableSizeBytes: Long = 0L, + tableProvisionedThroughput: Option[(Int, Int)] = None, + configuredScanSegments: Option[Int] = None, + configuredMaxMapTasks: Option[Int] = None, + configuredReadThroughput: Option[Int] = None, + configuredThroughputReadPercent: Option[Float] = None + ): Unit = { + val tableDescription = + new TableDescription() + .withTableName("DummyTable") + .withTableSizeBytes(tableSizeBytes) + tableProvisionedThroughput match { + case Some((rcu, wcu)) => + tableDescription.withProvisionedThroughput( + new ProvisionedThroughputDescription() + .withReadCapacityUnits(rcu) + .withWriteCapacityUnits(wcu) + ) + case None => + tableDescription.withProvisionedThroughput(new ProvisionedThroughputDescription()) + .withBillingModeSummary(new BillingModeSummary().withBillingMode(BillingMode.PAY_PER_REQUEST)) + } + + val jobConf = DynamoDB.makeJobConf( + spark = spark, + endpoint = None, + credentials = None, + region = None, + table = "DummyTable", + scanSegments = configuredScanSegments, + maxMapTasks = configuredMaxMapTasks, + readThroughput = configuredReadThroughput, + throughputReadPercent = configuredThroughputReadPercent, + description = tableDescription + ) + val splits = new DynamoDBInputFormat().getSplits(jobConf, 1) + + val partitions = splits.length + assertEquals(partitions, expectedPartitions) + } + + override def afterAll(): Unit = { + spark.stop() + } + +}