Replace our fork of spark-kinesis with our adapted copy of the relevant classes.

We adapted the `KinesisReceiver` and its related classes to work with DynamoDB Streams and renamed it to `KinesisDynamoDBReceiver`. These classes are based on the code of the original `spark-kinesis-asl` module, with slight modifications informed by the following resources (a short usage sketch follows the list):

- https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Streams.KCLAdapter.Walkthrough.CompleteProgram.html
- https://medium.com/@ravi72munde/using-spark-streaming-with-dynamodb-d325b9a73c79
- and our previous fork implementation
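
For context, here is a minimal, hypothetical usage sketch of the adapted classes (the table name, region, credentials, and batch interval are placeholder values, not what the migrator uses); the actual wiring lives in `DynamoStreamReplication`, shown further down in this diff:

```scala
import com.amazonaws.services.dynamodbv2.streamsadapter.model.RecordAdapter
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.{
  KinesisDynamoDBInputDStream,
  KinesisInitialPositions,
  SparkAWSCredentials
}

// Placeholder streaming context; the migrator builds its own from the SparkSession.
val ssc = new StreamingContext(
  new SparkConf().setMaster("local[*]").setAppName("kinesis-dynamodb-sketch"),
  Seconds(5)
)

val changes = new KinesisDynamoDBInputDStream(
  ssc,
  streamName        = "my-dynamodb-table", // table name; the receiver resolves the stream ARN itself
  regionName        = "us-east-1",
  initialPosition   = new KinesisInitialPositions.TrimHorizon,
  checkpointAppName = s"migrator_my-dynamodb-table_${System.currentTimeMillis()}",
  messageHandler = {
    // Records delivered by the DynamoDB Streams adapter are RecordAdapter instances
    case adapter: RecordAdapter => Option(adapter.getInternalObject.getDynamodb.getNewImage)
    case _                      => None
  },
  kinesisCreds = SparkAWSCredentials.builder
    .basicCredentials("ACCESS_KEY", "SECRET_KEY") // placeholder credentials
    .build()
)

changes.foreachRDD { rdd =>
  // apply the captured change events to the target table ...
}
ssc.start()
```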

Consequently, instead of maintaining a complete fork of `spark-kinesis-asl`, we now only maintain a copy of the relevant classes, which should make builds much faster (especially in CI).
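
To summarize what "adapted" means here, the sketch below (simplified and hypothetical: client configurations, backoff settings, and the record-processor factory are omitted) shows the core change in the copied receiver: instead of pointing the KCL at a Kinesis stream by name, it resolves the table's latest stream ARN and builds the worker through the DynamoDB Streams adapter:

```scala
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder
import com.amazonaws.services.dynamodbv2.{AmazonDynamoDBClientBuilder, AmazonDynamoDBStreamsClient}
import com.amazonaws.services.dynamodbv2.model.DescribeTableRequest
import com.amazonaws.services.dynamodbv2.streamsadapter.{AmazonDynamoDBStreamsAdapterClient, StreamsWorkerFactory}
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration

// Placeholder values for illustration only.
val tableName  = "my-dynamodb-table"
val regionName = "us-east-1"
val creds      = new DefaultAWSCredentialsProviderChain()

val dynamoDBClient = AmazonDynamoDBClientBuilder.standard()
  .withCredentials(creds)
  .withRegion(regionName)
  .build()

// The "stream name" handed to the KCL configuration is the table's latest stream ARN.
val streamArn = dynamoDBClient
  .describeTable(new DescribeTableRequest(tableName))
  .getTable
  .getLatestStreamArn

val kclConfig = new KinesisClientLibConfiguration("migrator-sketch", streamArn, creds, "worker-1")

// The worker consumes DynamoDB Streams through the Kinesis adapter instead of a plain Kinesis client.
val streamsAdapter = new AmazonDynamoDBStreamsAdapterClient(
  AmazonDynamoDBStreamsClient.builder()
    .withCredentials(creds)
    .withRegion(regionName)
    .build()
)

val cloudWatchClient = AmazonCloudWatchClientBuilder.standard()
  .withCredentials(creds)
  .withRegion(regionName)
  .build()

// recordProcessorFactory (a v2.IRecordProcessorFactory) is elided; the real receiver wraps
// its KinesisDynamoDBRecordProcessor in it before creating the worker:
// val worker = StreamsWorkerFactory.createDynamoDbStreamsWorker(
//   recordProcessorFactory, kclConfig, streamsAdapter, dynamoDBClient, cloudWatchClient)
```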

It is still not possible to test the streaming feature locally (and therefore not in CI either); see scylladb#113. These changes were tested against my own AWS account.

Fixes scylladb#119
julienrf committed Apr 26, 2024
1 parent af37817 commit a76cb48
Showing 10 changed files with 199 additions and 77 deletions.
3 changes: 0 additions & 3 deletions .gitmodules
@@ -2,6 +2,3 @@
path = spark-cassandra-connector
url = https://github.com/scylladb/spark-cassandra-connector
branch = feature/track-token-ranges
[submodule "spark-kinesis"]
path = spark-kinesis
url = https://github.com/scylladb/spark-kinesis
44 changes: 29 additions & 15 deletions build.sbt
@@ -2,6 +2,7 @@ import sbt.librarymanagement.InclExclRule

val awsSdkVersion = "1.11.728"
val sparkVersion = "2.4.4"
val dynamodbStreamsKinesisAdapterVersion = "1.5.2"

inThisBuild(
List(
@@ -11,9 +12,12 @@
)
)

// Adaptation of spark-streaming-kinesis-asl to work with DynamoDB Streams
// Augmentation of spark-streaming-kinesis-asl to also work with DynamoDB Streams
lazy val `spark-kinesis-dynamodb` = project.in(file("spark-kinesis-dynamodb")).settings(
libraryDependencies += "org.apache.spark" %% "spark-streaming-kinesis-asl" % sparkVersion,
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-streaming-kinesis-asl" % sparkVersion,
"com.amazonaws" % "dynamodb-streams-kinesis-adapter" % dynamodbStreamsKinesisAdapterVersion
)
)

lazy val migrator = (project in file("migrator")).settings(
@@ -28,14 +32,14 @@
"-XX:+CMSClassUnloadingEnabled"),
scalacOptions ++= Seq("-deprecation", "-unchecked", "-Ypartial-unification"),
Test / parallelExecution := false,
fork := true,
scalafmtOnCompile := true,
fork := true,
scalafmtOnCompile := true,
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-streaming" % sparkVersion % "provided",
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
"com.amazonaws" % "aws-java-sdk-sts" % awsSdkVersion,
"com.amazonaws" % "aws-java-sdk-dynamodb" % awsSdkVersion,
("com.amazonaws" % "dynamodb-streams-kinesis-adapter" % "1.5.2")
("com.amazonaws" % "dynamodb-streams-kinesis-adapter" % dynamodbStreamsKinesisAdapterVersion)
.excludeAll(InclExclRule("com.fasterxml.jackson.core")),
"com.amazon.emr" % "emr-dynamodb-hadoop" % "4.16.0",
"io.circe" %% "circe-yaml" % "0.10.1",
@@ -45,11 +49,21 @@
ShadeRule.rename("org.yaml.snakeyaml.**" -> "com.scylladb.shaded.@1").inAll
),
assembly / assemblyMergeStrategy := {
case PathList("org", "joda", "time", _ @_*) => MergeStrategy.first
// Handle conflicts between our own library dependencies and those that are bundled into
// the spark-cassandra-connector fat-jar
case PathList("com", "codahale", "metrics", _ @_*) => MergeStrategy.first
case PathList("digesterRules.xml") => MergeStrategy.first
case PathList("org", "aopalliance", _ @_*) => MergeStrategy.first
case PathList("org", "apache", "commons", "collections", _ @_*) => MergeStrategy.first
case PathList("org", "apache", "commons", "configuration", _ @_*) => MergeStrategy.first
case PathList("org", "apache", "commons", "logging", _ @_*) => MergeStrategy.first
case PathList("com", "fasterxml", "jackson", "annotation", _ @_*) => MergeStrategy.first
case PathList("com", "fasterxml", "jackson", "core", _ @_*) => MergeStrategy.first
case PathList("com", "fasterxml", "jackson", "databind", _ @_*) => MergeStrategy.first
case PathList("org", "apache", "spark", _ @_*) => MergeStrategy.first
case PathList("org", "slf4j", _ @_*) => MergeStrategy.first
case PathList("properties.dtd") => MergeStrategy.first
case PathList("PropertyList-1.0.dtd") => MergeStrategy.first
// Other conflicts
case PathList("javax", "inject", _ @_*) => MergeStrategy.first
case PathList("org", "apache", "hadoop", _ @_*) => MergeStrategy.first
case x =>
val oldStrategy = (assembly / assemblyMergeStrategy).value
oldStrategy(x)
@@ -77,12 +91,12 @@

lazy val tests = project.in(file("tests")).settings(
libraryDependencies ++= Seq(
"com.amazonaws" % "aws-java-sdk-dynamodb" % awsSdkVersion,
"org.apache.cassandra" % "java-driver-query-builder" % "4.18.0",
"com.github.mjakubowski84" %% "parquet4s-core" % "1.9.4",
"org.apache.hadoop" % "hadoop-client" % "2.9.2",
"org.scalameta" %% "munit" % "0.7.29",
"org.scala-lang.modules" %% "scala-collection-compat" % "2.11.0"
"com.amazonaws" % "aws-java-sdk-dynamodb" % awsSdkVersion,
"org.apache.cassandra" % "java-driver-query-builder" % "4.18.0",
"com.github.mjakubowski84" %% "parquet4s-core" % "1.9.4",
"org.apache.hadoop" % "hadoop-client" % "2.9.2",
"org.scalameta" %% "munit" % "0.7.29",
"org.scala-lang.modules" %% "scala-collection-compat" % "2.11.0"
),
Test / parallelExecution := false
).dependsOn(migrator)
4 changes: 0 additions & 4 deletions build.sh
@@ -14,15 +14,11 @@ trap "rm -rf $TMPDIR" EXIT
pushd spark-cassandra-connector
sbt -Djava.io.tmpdir="$TMPDIR" ++2.11.12 assembly
popd
pushd spark-kinesis
sbt assembly
popd

if [ ! -d "./migrator/lib" ]; then
mkdir migrator/lib
fi

cp ./spark-cassandra-connector/connector/target/scala-2.11/spark-cassandra-connector-assembly-*.jar ./migrator/lib
cp ./spark-kinesis/target/scala-2.11/spark-streaming-kinesis-asl-assembly-*.jar ./migrator/lib

sbt -Djava.io.tmpdir="$TMPDIR" migrator/assembly
@@ -9,8 +9,8 @@ import org.apache.log4j.LogManager
import org.apache.spark.sql.SparkSession
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.kinesis.{
KinesisDynamoDBInputDStream,
KinesisInitialPositions,
KinesisInputDStream,
SparkAWSCredentials
}

@@ -29,22 +29,13 @@ object DynamoStreamReplication {
target: TargetSettings.DynamoDB,
targetTableDesc: TableDescription,
renames: List[Rename]): Unit =
KinesisInputDStream.builder
.streamingContext(streamingContext)
.streamName(src.table)
.dynamoStream(true)
.kinesisCredentials(
src.credentials.map {
case AWSCredentials(accessKey, secretKey) =>
SparkAWSCredentials.builder
.basicCredentials(accessKey, secretKey)
.build
}.orNull
)
.regionName(src.region.orNull)
.checkpointAppName(s"migrator_${src.table}_${System.currentTimeMillis()}")
.initialPosition(new KinesisInitialPositions.TrimHorizon)
.buildWithMessageHandler {
new KinesisDynamoDBInputDStream(
streamingContext,
streamName = src.table,
regionName = src.region.orNull,
initialPosition = new KinesisInitialPositions.TrimHorizon,
checkpointAppName = s"migrator_${src.table}_${System.currentTimeMillis()}",
messageHandler = {
case recAdapter: RecordAdapter =>
val rec = recAdapter.getInternalObject
val newMap = new util.HashMap[String, AttributeValue]()
@@ -63,35 +54,41 @@
Some(newMap)

case _ => None
}
.foreachRDD { msgs =>
val rdd = msgs
.collect { case Some(item) => new DynamoDBItemWritable(item) }
.repartition(Runtime.getRuntime.availableProcessors() * 2)
.map(item => (new Text, item)) // Create the key after repartitioning to avoid Serialization issues
},
kinesisCreds = src.credentials.map {
case AWSCredentials(accessKey, secretKey) =>
SparkAWSCredentials.builder
.basicCredentials(accessKey, secretKey)
.build()
}.orNull
).foreachRDD { msgs =>
val rdd = msgs
.collect { case Some(item) => new DynamoDBItemWritable(item) }
.repartition(Runtime.getRuntime.availableProcessors() * 2)
.map(item => (new Text, item)) // Create the key after repartitioning to avoid Serialization issues

val changes =
rdd
.map(_._2) // Remove keys because they are not serializable
.groupBy { itemWritable =>
itemWritable.getItem.get(operationTypeColumn) match {
case `putOperation` => "UPSERT"
case `deleteOperation` => "DELETE"
case _ => "UNKNOWN"
}
val changes =
rdd
.map(_._2) // Remove keys because they are not serializable
.groupBy { itemWritable =>
itemWritable.getItem.get(operationTypeColumn) match {
case `putOperation` => "UPSERT"
case `deleteOperation` => "DELETE"
case _ => "UNKNOWN"
}
.mapValues(_.size)
.collect()
if (changes.nonEmpty) {
log.info("Changes to be applied:")
for ((operation, count) <- changes) {
log.info(s"${operation}: ${count}")
}
} else {
log.info("No changes to apply")
.mapValues(_.size)
.collect()
if (changes.nonEmpty) {
log.info("Changes to be applied:")
for ((operation, count) <- changes) {
log.info(s"${operation}: ${count}")
}

DynamoDB.writeRDD(target, renames, rdd, Some(targetTableDesc))(spark)
} else {
log.info("No changes to apply")
}

DynamoDB.writeRDD(target, renames, rdd, Some(targetTableDesc))(spark)
}

}
1 change: 0 additions & 1 deletion spark-kinesis
Submodule spark-kinesis deleted from d7ba51
@@ -36,8 +36,8 @@ import org.apache.spark.util.{Clock, SystemClock}
* @param workerId Worker Id of KCL worker for logging purposes
* @param clock In order to use ManualClocks for the purpose of testing
*/
private[kinesis] class KinesisCheckpointer(
receiver: KinesisReceiver[_],
private[kinesis] class KinesisDynamoDBCheckpointer(
receiver: KinesisDynamoDBReceiver[_],
checkpointInterval: Duration,
workerId: String,
clock: Clock = new SystemClock) extends Logging {
@@ -0,0 +1,52 @@
package org.apache.spark.streaming.kinesis

import com.amazonaws.services.kinesis.model.Record
import org.apache.spark.streaming.kinesis.KinesisInputDStream.{DEFAULT_KINESIS_ENDPOINT_URL, DEFAULT_STORAGE_LEVEL}
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.streaming.StreamingContext

import scala.reflect.ClassTag

/**
* Override the default behavior of [[KinesisInputDStream]] to create a [[KinesisDynamoDBReceiver]].
*/
class KinesisDynamoDBInputDStream[T: ClassTag](
ssc: StreamingContext,
streamName: String,
regionName: String,
initialPosition: KinesisInitialPosition,
checkpointAppName: String,
messageHandler: Record => T,
kinesisCreds: SparkAWSCredentials
) extends KinesisInputDStream[T](
ssc,
streamName,
DEFAULT_KINESIS_ENDPOINT_URL,
regionName,
initialPosition,
checkpointAppName,
ssc.graph.batchDuration,
DEFAULT_STORAGE_LEVEL,
messageHandler,
kinesisCreds,
None,
None
) {

override def getReceiver(): Receiver[T] = {
new KinesisDynamoDBReceiver(
streamName,
endpointUrl,
regionName,
initialPosition,
checkpointAppName,
checkpointInterval,
DEFAULT_STORAGE_LEVEL,
messageHandler,
kinesisCreds,
dynamoDBCreds,
cloudWatchCreds
)
}

}
@@ -23,7 +23,12 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.control.NonFatal

import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory}
import com.amazonaws.services.cloudwatch.AmazonCloudWatchClientBuilder
import com.amazonaws.services.dynamodbv2.{AmazonDynamoDBClientBuilder, AmazonDynamoDBStreamsClient}
import com.amazonaws.services.dynamodbv2.model.DescribeTableRequest
import com.amazonaws.services.dynamodbv2.streamsadapter.{AmazonDynamoDBStreamsAdapterClient, StreamsWorkerFactory}
import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
import com.amazonaws.services.kinesis.clientlibrary.interfaces.v2
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{KinesisClientLibConfiguration, Worker}
import com.amazonaws.services.kinesis.model.Record

@@ -81,7 +86,7 @@ import org.apache.spark.util.Utils
* AWSCredentialsProvider passed to the KCL to authorize DynamoDB API calls.
* Will use kinesisCreds if value is None.
*/
private[kinesis] class KinesisReceiver[T](
private[kinesis] class KinesisDynamoDBReceiver[T](
val streamName: String,
endpointUrl: String,
regionName: String,
@@ -132,7 +137,7 @@ private[kinesis] class KinesisReceiver[T](
/**
* The centralized kinesisCheckpointer that checkpoints based on the given checkpointInterval.
*/
@volatile private var kinesisCheckpointer: KinesisCheckpointer = null
@volatile private var kinesisCheckpointer: KinesisDynamoDBCheckpointer = null

/**
* Latest sequence number ranges that have been stored successfully.
@@ -148,20 +153,38 @@

workerId = Utils.localHostName() + ":" + UUID.randomUUID()

kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId)
kinesisCheckpointer =
new KinesisDynamoDBCheckpointer(receiver, checkpointInterval, workerId)
val kinesisProvider = kinesisCreds.provider

val dynamoDBClient =
AmazonDynamoDBClientBuilder
.standard()
.withCredentials(dynamoDBCreds.fold(kinesisProvider)(_.provider))
.withRegion(regionName)
.build()

val actualStreamName =
dynamoDBClient
.describeTable(new DescribeTableRequest(streamName))
.getTable
.getLatestStreamArn

val kinesisClientLibConfiguration = {
val baseClientLibConfiguration = new KinesisClientLibConfiguration(
checkpointAppName,
streamName,
actualStreamName,
kinesisProvider,
dynamoDBCreds.map(_.provider).getOrElse(kinesisProvider),
cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider),
workerId)
.withKinesisEndpoint(endpointUrl)
.withTaskBackoffTimeMillis(500)
.withRegionName(regionName)
.withIdleTimeBetweenReadsInMillis(500)
.withMaxRecords(1000)
.withFailoverTimeMillis(60000)
.withParentShardPollIntervalMillis(10000)

// Update the Kinesis client lib config with timestamp
// if InitialPositionInStream.AT_TIMESTAMP is passed
@@ -179,12 +202,36 @@
* IRecordProcessor.processRecords() method.
* We're using our custom KinesisRecordProcessor in this case.
*/
val recordProcessorFactory = new IRecordProcessorFactory {
override def createProcessor: IRecordProcessor =
new KinesisRecordProcessor(receiver, workerId)
val recordProcessorFactory = new v2.IRecordProcessorFactory {
override def createProcessor(): v2.IRecordProcessor =
new V1ToV2RecordProcessor(new KinesisDynamoDBRecordProcessor(receiver, workerId))
}

worker = {
val streamsAdapter = new AmazonDynamoDBStreamsAdapterClient(
AmazonDynamoDBStreamsClient.builder()
.withCredentials(kinesisProvider)
.withRegion(regionName)
.build()
)

val cloudWatchClient = AmazonCloudWatchClientBuilder.standard
.withCredentials(cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider))
.withRegion(regionName)
.withClientConfiguration(
kinesisClientLibConfiguration.getCloudWatchClientConfiguration
)
.build

StreamsWorkerFactory.createDynamoDbStreamsWorker(
recordProcessorFactory,
kinesisClientLibConfiguration,
streamsAdapter,
dynamoDBClient,
cloudWatchClient
)
}

worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration)
workerThread = new Thread() {
override def run(): Unit = {
try {