diff --git a/README.md b/README.md
index 3a58c45..d4102ce 100644
--- a/README.md
+++ b/README.md
@@ -149,7 +149,9 @@ Refering $SPARK_HOME to the Spark installation directory.
 | kinesis.executor.recordMaxBufferedTime | 1000 (millis) | Specify the maximum buffered time of a record |
 | kinesis.executor.maxConnections | 1 | Specify the maximum connections to Kinesis |
 | kinesis.executor.aggregationEnabled | true | Specify if records should be aggregated before sending them to Kinesis |
-| kniesis.executor.flushwaittimemillis | 100 | Wait time while flushing records to Kinesis on Task End |
+| kinesis.executor.flushwaittimemillis | 100 | Wait time while flushing records to Kinesis on Task End |
+| kinesis.executor.sink.bundle.records | false | Bundle records from one micro-batch into PutRecords request |
+| kinesis.executor.sink.max.bundle.records | 500 | Max number of records in each PutRecords request |
 
 ## Roadmap
 * We need to migrate to DataSource V2 APIs for MicroBatchExecution.
diff --git a/src/main/scala/org/apache/spark/sql/kinesis/CachedKinesisProducer.scala b/src/main/scala/org/apache/spark/sql/kinesis/CachedKinesisProducer.scala
index 9db1991..8462443 100644
--- a/src/main/scala/org/apache/spark/sql/kinesis/CachedKinesisProducer.scala
+++ b/src/main/scala/org/apache/spark/sql/kinesis/CachedKinesisProducer.scala
@@ -22,16 +22,15 @@ import java.util.concurrent.{ExecutionException, TimeUnit}
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
-import com.amazonaws.auth.{AWSStaticCredentialsProvider, BasicAWSCredentials}
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+
 import com.amazonaws.regions.RegionUtils
 import com.amazonaws.services.kinesis.AmazonKinesis
 import com.amazonaws.services.kinesis.producer.{KinesisProducer, KinesisProducerConfiguration}
 import com.google.common.cache._
 import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
 
-import org.apache.spark.SparkEnv
-import org.apache.spark.internal.Logging
-
 private[kinesis] object CachedKinesisProducer extends Logging {
 
   private type Producer = KinesisProducer
@@ -69,6 +68,11 @@ private[kinesis] object CachedKinesisProducer extends Logging {
       .map { k => k.drop(8).toString -> producerConfiguration(k) }
       .toMap
 
+    val recordTtl = kinesisParams.getOrElse(
+      KinesisSourceProvider.SINK_RECORD_TTL,
+      KinesisSourceProvider.DEFAULT_SINK_RECORD_TTL)
+      .toLong
+
     val recordMaxBufferedTime = kinesisParams.getOrElse(
       KinesisSourceProvider.SINK_RECORD_MAX_BUFFERED_TIME,
       KinesisSourceProvider.DEFAULT_SINK_RECORD_MAX_BUFFERED_TIME)
@@ -123,6 +127,7 @@ private[kinesis] object CachedKinesisProducer extends Logging {
     }
 
     val kinesisProducer = new Producer(new KinesisProducerConfiguration()
+      .setRecordTtl(recordTtl)
       .setRecordMaxBufferedTime(recordMaxBufferedTime)
       .setMaxConnections(maxConnections)
       .setAggregationEnabled(aggregation)
diff --git a/src/main/scala/org/apache/spark/sql/kinesis/KinesisSourceProvider.scala b/src/main/scala/org/apache/spark/sql/kinesis/KinesisSourceProvider.scala
index cad76e9..9a91a68 100644
--- a/src/main/scala/org/apache/spark/sql/kinesis/KinesisSourceProvider.scala
+++ b/src/main/scala/org/apache/spark/sql/kinesis/KinesisSourceProvider.scala
@@ -51,10 +51,10 @@ private[kinesis] class KinesisSourceProvider extends DataSourceRegister
    */
   override def sourceSchema(
-    sqlContext: SQLContext,
-    schema: Option[StructType],
-    providerName: String,
-    parameters: Map[String, String]): (String, StructType) = {
+      sqlContext: SQLContext,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): (String, StructType) = {
     val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
     validateStreamOptions(caseInsensitiveParams)
     require(schema.isEmpty, "Kinesis source has a fixed schema and cannot be set with a custom one")
@@ -62,11 +62,11 @@ private[kinesis] class KinesisSourceProvider extends DataSourceRegister
   }
 
   override def createSource(
-    sqlContext: SQLContext,
-    metadataPath: String,
-    schema: Option[StructType],
-    providerName: String,
-    parameters: Map[String, String]): Source = {
+      sqlContext: SQLContext,
+      metadataPath: String,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
 
     val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
@@ -138,7 +138,7 @@ private[kinesis] class KinesisSourceProvider extends DataSourceRegister
         "Sink endpoint url is a required field")
     }
     if (caseInsensitiveParams.contains(SINK_AGGREGATION_ENABLED) && (
-      caseInsensitiveParams(SINK_AGGREGATION_ENABLED).trim != "true" &&
+        caseInsensitiveParams(SINK_AGGREGATION_ENABLED).trim != "true" &&
       caseInsensitiveParams(SINK_AGGREGATION_ENABLED).trim != "false"
     )) {
       throw new IllegalArgumentException(
@@ -235,14 +235,17 @@ private[kinesis] object KinesisSourceProvider extends Logging {
 
   // Sink Options
   private[kinesis] val SINK_STREAM_NAME_KEY = "streamname"
   private[kinesis] val SINK_ENDPOINT_URL = "endpointurl"
+  private[kinesis] val SINK_RECORD_TTL = "kinesis.executor.recordTtl"
   private[kinesis] val SINK_RECORD_MAX_BUFFERED_TIME = "kinesis.executor.recordmaxbufferedtime"
   private[kinesis] val SINK_MAX_CONNECTIONS = "kinesis.executor.maxconnections"
   private[kinesis] val SINK_AGGREGATION_ENABLED = "kinesis.executor.aggregationenabled"
-  private[kinesis] val SINK_FLUSH_WAIT_TIME_MILLIS = "kniesis.executor.flushwaittimemillis"
+  private[kinesis] val SINK_FLUSH_WAIT_TIME_MILLIS = "kinesis.executor.flushwaittimemillis"
+  private[kinesis] val SINK_BUNDLE_RECORDS = "kinesis.executor.sink.bundle.records"
+  private[kinesis] val SINK_MAX_BUNDLE_RECORDS = "kinesis.executor.sink.max.bundle.records"
 
-  private[kinesis] def getKinesisPosition(
-      params: Map[String, String]): InitialKinesisPosition = {
+
+  private[kinesis] def getKinesisPosition(params: Map[String, String]): InitialKinesisPosition = {
     val CURRENT_TIMESTAMP = System.currentTimeMillis
     params.get(STARTING_POSITION_KEY).map(_.trim) match {
       case Some(position) if position.toLowerCase(Locale.ROOT) == "latest" =>
@@ -262,6 +265,8 @@ private[kinesis] object KinesisSourceProvider extends Logging {
 
   private[kinesis] val DEFAULT_KINESIS_REGION_NAME: String = "us-east-1"
 
+  private[kinesis] val DEFAULT_SINK_RECORD_TTL: String = "30000"
+
   private[kinesis] val DEFAULT_SINK_RECORD_MAX_BUFFERED_TIME: String = "1000"
 
   private[kinesis] val DEFAULT_SINK_MAX_CONNECTIONS: String = "1"
@@ -269,7 +274,10 @@ private[kinesis] object KinesisSourceProvider extends Logging {
 
   private[kinesis] val DEFAULT_SINK_AGGREGATION: String = "true"
 
   private[kinesis] val DEFAULT_FLUSH_WAIT_TIME_MILLIS: String = "100"
 
-}
+  private[kinesis] val DEFAULT_SINK_BUNDLE_RECORDS: String = "false"
+  private[kinesis] val DEFAULT_SINK_MAX_BUNDLE_RECORDS: String = "500"
+
+}
diff --git a/src/main/scala/org/apache/spark/sql/kinesis/KinesisWriteTask.scala b/src/main/scala/org/apache/spark/sql/kinesis/KinesisWriteTask.scala
index c4dcbfe..5c6f3b6 100644
--- a/src/main/scala/org/apache/spark/sql/kinesis/KinesisWriteTask.scala
+++ b/src/main/scala/org/apache/spark/sql/kinesis/KinesisWriteTask.scala
@@ -43,11 +43,80 @@ private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, Strin
       s"${KinesisSourceProvider.SINK_FLUSH_WAIT_TIME_MILLIS} has to be a positive integer")
   }
 
+  private val sinkBundleRecords = Try(producerConfiguration.getOrElse(
+    KinesisSourceProvider.SINK_BUNDLE_RECORDS,
+    KinesisSourceProvider.DEFAULT_SINK_BUNDLE_RECORDS).toBoolean).getOrElse {
+    throw new IllegalArgumentException(
+      s"${KinesisSourceProvider.SINK_BUNDLE_RECORDS} has to be a boolean value")
+  }
+
+  private val maxBundleRecords = Try(producerConfiguration.getOrElse(
+    KinesisSourceProvider.SINK_MAX_BUNDLE_RECORDS,
+    KinesisSourceProvider.DEFAULT_SINK_MAX_BUNDLE_RECORDS).toInt).getOrElse {
+    throw new IllegalArgumentException(
+      s"${KinesisSourceProvider.SINK_MAX_BUNDLE_RECORDS} has to be an integer value")
+  }
+
   private var failedWrite: Throwable = _
 
   def execute(iterator: Iterator[InternalRow]): Unit = {
+
+    if (sinkBundleRecords) {
+      bundleExecute(iterator)
+    } else {
+      singleExecute(iterator)
+    }
+
+  }
+
+  private def bundleExecute(iterator: Iterator[InternalRow]): Unit = {
+
+    val groupedIterator: iterator.GroupedIterator[InternalRow] = iterator.grouped(maxBundleRecords)
+
+    while (groupedIterator.hasNext) {
+      val rowList = groupedIterator.next()
+      sendBundledData(rowList)
+    }
+
+  }
+
+  private def sendBundledData(rowList: Seq[InternalRow]): Unit = {
+    producer = CachedKinesisProducer.getOrCreate(producerConfiguration)
+
+    val kinesisCallBack = new FutureCallback[UserRecordResult]() {
+
+      override def onFailure(t: Throwable): Unit = {
+        if (failedWrite == null && t != null) {
+          failedWrite = t
+          logError(s"Writing to $streamName failed due to ${t.getCause}")
+        }
+      }
+
+      override def onSuccess(result: UserRecordResult): Unit = {
+        logDebug(s"Successfully put records: \n " +
+          s"sequenceNumber=${result.getSequenceNumber}, \n" +
+          s"shardId=${result.getShardId}, \n" +
+          s"attempts=${result.getAttempts.size}")
+      }
+    }
+
+    for (r <- rowList) {
+
+      val projectedRow = projection(r)
+      val partitionKey = projectedRow.getString(0)
+      val data = projectedRow.getBinary(1)
+
+      val future = producer.addUserRecord(streamName, partitionKey, ByteBuffer.wrap(data))
+
+      Futures.addCallback(future, kinesisCallBack)
+
+    }
+  }
+
+  private def singleExecute(iterator: Iterator[InternalRow]): Unit = {
     producer = CachedKinesisProducer.getOrCreate(producerConfiguration)
+
     while (iterator.hasNext && failedWrite == null) {
       val currentRow = iterator.next()
       val projectedRow = projection(currentRow)
@@ -56,11 +125,10 @@ private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, Strin
       sendData(partitionKey, data)
     }
-  }
 
-  def sendData(partitionKey: String, data: Array[Byte]): String = {
-    var sentSeqNumbers = new String
+  }
 
+  private def sendData(partitionKey: String, data: Array[Byte]): Unit = {
     val future = producer.addUserRecord(streamName, partitionKey, ByteBuffer.wrap(data))
 
     val kinesisCallBack = new FutureCallback[UserRecordResult]() {
@@ -73,14 +141,17 @@ private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, Strin
       }
 
       override def onSuccess(result: UserRecordResult): Unit = {
-        val shardId = result.getShardId
-        sentSeqNumbers = result.getSequenceNumber
+        logDebug(s"Successfully put records: \n " +
+          s"sequenceNumber=${result.getSequenceNumber}, \n" +
+          s"shardId=${result.getShardId}, \n" +
+          s"attempts=${result.getAttempts.size}")
       }
+
     }
+
     Futures.addCallback(future, kinesisCallBack)
 
     producer.flushSync()
-    sentSeqNumbers
   }
 
   private def flushRecordsIfNecessary(): Unit = {
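
For reference, a minimal usage sketch (not part of the patch) showing the new sink options end to end. It assumes the connector registers under the `kinesis` short name; the stream name, endpoint URL, and checkpoint path are placeholders, and the rate source is only a stand-in input producing the `partitionKey` and `data` columns the sink expects.

```scala
// Hypothetical usage sketch; stream name, endpoint URL, and paths are placeholders.
import org.apache.spark.sql.SparkSession

object KinesisBundledSinkExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("KinesisBundledSinkExample")
      .getOrCreate()

    // Rate source as stand-in input; the Kinesis sink consumes
    // `partitionKey` and `data` columns.
    val input = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "100")
      .load()
      .selectExpr(
        "CAST(value AS STRING) AS partitionKey",
        "CAST(CAST(value AS STRING) AS BINARY) AS data")

    val query = input.writeStream
      .format("kinesis")
      .option("streamName", "my-example-stream")                        // placeholder
      .option("endpointUrl", "https://kinesis.us-east-1.amazonaws.com") // placeholder
      .option("kinesis.executor.recordTtl", "30000")                    // new KPL record TTL option
      .option("kinesis.executor.sink.bundle.records", "true")           // bundle records per micro-batch
      .option("kinesis.executor.sink.max.bundle.records", "500")        // records per bundle
      .option("checkpointLocation", "/tmp/kinesis-sink-checkpoint")     // placeholder
      .start()

    query.awaitTermination()
  }
}
```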