Commit 32a58a4

[SPARK-54027] Kafka Source RTM support
1 parent: f0030ed · commit: 32a58a4

File tree: 9 files changed, +1390 −14 lines


common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java

Lines changed: 1 addition & 0 deletions
@@ -823,6 +823,7 @@ public enum LogKeys implements LogKey {
   TIMEOUT,
   TIMER,
   TIMESTAMP,
+  TIMESTAMP_COLUMN_NAME,
   TIME_UNITS,
   TIP,
   TOKEN,

connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala

Lines changed: 46 additions & 3 deletions
@@ -19,15 +19,19 @@ package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
 
+import org.apache.kafka.common.record.TimestampType
+
 import org.apache.spark.TaskContext
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{Logging, LogKeys}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.connector.metric.CustomTaskMetric
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
+import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead
+import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead.RecordStatus
 import org.apache.spark.sql.execution.streaming.runtime.{MicroBatchExecution, StreamExecution}
-import org.apache.spark.sql.kafka010.consumer.KafkaDataConsumer
+import org.apache.spark.sql.kafka010.consumer.{KafkaDataConsumer, KafkaDataConsumerIterator}
 
 /** A [[InputPartition]] for reading Kafka data in a batch based streaming query. */
 private[kafka010] case class KafkaBatchInputPartition(

@@ -67,7 +71,8 @@ private case class KafkaBatchPartitionReader(
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
     failOnDataLoss: Boolean,
-    includeHeaders: Boolean) extends PartitionReader[InternalRow] with Logging {
+    includeHeaders: Boolean)
+  extends SupportsRealTimeRead[InternalRow] with Logging {
 
   private val consumer = KafkaDataConsumer.acquire(offsetRange.topicPartition, executorKafkaParams)
 

@@ -77,6 +82,12 @@ private case class KafkaBatchPartitionReader(
 
   private var nextOffset = rangeToRead.fromOffset
   private var nextRow: UnsafeRow = _
+  private var iteratorForRealTimeMode: Option[KafkaDataConsumerIterator] = None
+
+  // Boolean flag that indicates whether we have logged the type of timestamp (i.e. create time,
+  // log-append time, etc.) for the Kafka source. We log upon reading the first record, and we
+  // then skip logging for subsequent records.
+  private var timestampTypeLogged = false
 
   override def next(): Boolean = {
     if (nextOffset < rangeToRead.untilOffset) {

@@ -93,6 +104,38 @@
     }
   }
 
+  override def nextWithTimeout(timeoutMs: java.lang.Long): RecordStatus = {
+    if (!iteratorForRealTimeMode.isDefined) {
+      logInfo(s"Getting a new kafka consuming iterator for ${offsetRange.topicPartition} " +
+        s"starting from ${nextOffset}, timeoutMs ${timeoutMs}")
+      iteratorForRealTimeMode = Some(consumer.getIterator(nextOffset))
+    }
+    assert(iteratorForRealTimeMode.isDefined)
+    val nextRecord = iteratorForRealTimeMode.get.nextWithTimeout(timeoutMs)
+    nextRecord.foreach { record =>
+
+      nextRow = unsafeRowProjector(record)
+      nextOffset = record.offset + 1
+      if (record.timestampType() == TimestampType.LOG_APPEND_TIME ||
+        record.timestampType() == TimestampType.CREATE_TIME) {
+        if (!timestampTypeLogged) {
+          logInfo(log"Kafka source record timestamp type is " +
+            log"${MDC(LogKeys.TIMESTAMP_COLUMN_NAME, record.timestampType())}")
+          timestampTypeLogged = true
+        }
+
+        RecordStatus.newStatusWithArrivalTimeMs(record.timestamp())
+      } else {
+        RecordStatus.newStatusWithoutArrivalTime(true)
+      }
+    }
+    RecordStatus.newStatusWithoutArrivalTime(nextRecord.isDefined)
+  }
+
+  override def getOffset(): KafkaSourcePartitionOffset = {
+    KafkaSourcePartitionOffset(offsetRange.topicPartition, nextOffset)
+  }
+
   override def get(): UnsafeRow = {
     assert(nextRow != null)
     nextRow
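
Note: taken together, the reader now exposes a pull-with-timeout surface (nextWithTimeout returning a RecordStatus, plus get and getOffset) alongside the classic next/get loop. A minimal sketch of how an executor-side loop might drive it in real time mode is below; the deadline handling, the process callback, and the hasRecord accessor on RecordStatus are illustrative assumptions, not part of this commit.

  // Sketch only: hypothetical caller of the real-time read API added above.
  def drainUntilDeadline(
      reader: KafkaBatchPartitionReader,
      deadlineMs: Long,
      process: InternalRow => Unit): KafkaSourcePartitionOffset = {
    var remainingMs = deadlineMs - System.currentTimeMillis()
    while (remainingMs > 0) {
      // Block for at most remainingMs waiting for the next Kafka record.
      val status = reader.nextWithTimeout(remainingMs)
      if (status.hasRecord) {        // assumed accessor on RecordStatus
        process(reader.get())        // UnsafeRow projected from the polled ConsumerRecord
      }
      remainingMs = deadlineMs - System.currentTimeMillis()
    }
    reader.getOffset()               // per-partition offset, later combined via mergeOffsets
  }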

connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala

Lines changed: 126 additions & 6 deletions
@@ -60,7 +60,11 @@ private[kafka010] class KafkaMicroBatchStream(
     metadataPath: String,
     startingOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
-  extends SupportsTriggerAvailableNow with ReportsSourceMetrics with MicroBatchStream with Logging {
+  extends SupportsTriggerAvailableNow
+  with SupportsRealTimeMode
+  with ReportsSourceMetrics
+  with MicroBatchStream
+  with Logging {
 
   private[kafka010] val pollTimeoutMs = options.getLong(
     KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,

@@ -93,6 +97,11 @@ private[kafka010] class KafkaMicroBatchStream(
 
   private var isTriggerAvailableNow: Boolean = false
 
+  private var inRealTimeMode = false
+  override def prepareForRealTimeMode(): Unit = {
+    inRealTimeMode = true
+  }
+
   /**
    * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
    * called in StreamExecutionThread. Otherwise, interrupting a thread while running

@@ -218,6 +227,93 @@
     }.toArray
   }
 
+  override def planInputPartitions(start: Offset): Array[InputPartition] = {
+    // This function is used for real time mode. Trigger restrictions won't be supported.
+    if (maxOffsetsPerTrigger.isDefined) {
+      throw new UnsupportedOperationException(
+        "maxOffsetsPerTrigger is not compatible with real time mode")
+    }
+    if (minOffsetPerTrigger.isDefined) {
+      throw new UnsupportedOperationException(
+        "minOffsetsPerTrigger is not compatible with real time mode"
+      )
+    }
+    if (options.containsKey(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY)) {
+      throw new UnsupportedOperationException(
+        "minpartitions is not compatible with real time mode"
+      )
+    }
+    if (options.containsKey(KafkaSourceProvider.ENDING_TIMESTAMP_OPTION_KEY)) {
+      throw new UnsupportedOperationException(
+        "endingtimestamp is not compatible with real time mode"
+      )
+    }
+    if (options.containsKey(KafkaSourceProvider.MAX_TRIGGER_DELAY)) {
+      throw new UnsupportedOperationException(
+        "maxtriggerdelay is not compatible with real time mode"
+      )
+    }
+
+    // This function is used by Low Latency Mode, where we expect 1:1 mapping between a
+    // topic partition and an input partition.
+    // We are skipping partition range check for performance reason. We can always try to do
+    // it in tasks if needed.
+    val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets
+
+    // Here we check previous topic partitions with latest partition offsets to see if we need to
+    // update the partition list. Here we don't need the updated partition topic to be absolutely
+    // up to date, because there might already be minutes' delay since new partition is created.
+    // latestPartitionOffsets should be fetched not long ago anyway.
+    // If the topic partitions change, we fetch the earliest offsets for all new partitions
+    // and add them to the list.
+    assert(latestPartitionOffsets != null, "latestPartitionOffsets should be set in latestOffset")
+    val latestTopicPartitions = latestPartitionOffsets.keySet
+    val newStartPartitionOffsets = if (startPartitionOffsets.keySet == latestTopicPartitions) {
+      startPartitionOffsets
+    } else {
+      val newPartitions = latestTopicPartitions.diff(startPartitionOffsets.keySet)
+      // Instead of fetching earliest offsets, we could fill offset 0 here and avoid this extra
+      // admin function call. But we consider new partition is rare and getting earliest offset
+      // aligns with what we do in micro-batch mode and can potentially enable more sanity checks
+      // in executor side.
+      val newPartitionOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
+
+      assert(
+        newPartitionOffsets.keys.forall(!startPartitionOffsets.contains(_)),
+        "startPartitionOffsets should not contain any key in newPartitionOffsets")
+
+      // Filter out new partition offsets that are not 0 and log a warning
+      val nonZeroNewPartitionOffsets = newPartitionOffsets.filter {
+        case (_, offset) => offset != 0
+      }
+      // Log the non-zero new partition offsets
+      if (nonZeroNewPartitionOffsets.nonEmpty) {
+        logWarning(log"new partitions should start from offset 0: " +
+          log"${MDC(OFFSETS, nonZeroNewPartitionOffsets)}")
+      }
+
+      logInfo(log"Added new partition offsets: ${MDC(OFFSETS, newPartitionOffsets)}")
+      startPartitionOffsets ++ newPartitionOffsets
+    }
+
+    newStartPartitionOffsets.keySet.toSeq.map { tp =>
+      val fromOffset = newStartPartitionOffsets(tp)
+      KafkaBatchInputPartition(
+        KafkaOffsetRange(tp, fromOffset, Long.MaxValue, preferredLoc = None),
+        executorKafkaParams,
+        pollTimeoutMs,
+        failOnDataLoss,
+        includeHeaders)
+    }.toArray
+  }
+
+  override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
+    val mergedMap = offsets.map {
+      case KafkaSourcePartitionOffset(p, o) => (p, o)
+    }.toMap
+    KafkaSourceOffset(mergedMap)
+  }
+
   override def createReaderFactory(): PartitionReaderFactory = {
     KafkaBatchReaderFactory
   }

@@ -235,7 +331,30 @@
   override def toString(): String = s"KafkaV2[$kafkaOffsetReader]"
 
   override def metrics(latestConsumedOffset: Optional[Offset]): ju.Map[String, String] = {
-    KafkaMicroBatchStream.metrics(latestConsumedOffset, latestPartitionOffsets)
+    var rtmFetchLatestOffsetsTimeMs = Option.empty[Long]
+    val reCalculatedLatestPartitionOffsets =
+      if (inRealTimeMode) {
+        if (!latestConsumedOffset.isPresent) {
+          // this means a batch has no end offsets, which should not happen
+          None
+        } else {
+          Some {
+            val startTime = System.currentTimeMillis()
+            val latestOffsets = kafkaOffsetReader.fetchLatestOffsets(
+              Some(latestConsumedOffset.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets))
+            val endTime = System.currentTimeMillis()
+            rtmFetchLatestOffsetsTimeMs = Some(endTime - startTime)
+            latestOffsets
+          }
+        }
+      } else {
+        // If we are in micro-batch mode, we need to get the latest partition offsets at the
+        // start of the batch and recalculate the latest offsets at the end for backlog
+        // estimation.
+        Some(kafkaOffsetReader.fetchLatestOffsets(Some(latestPartitionOffsets)))
+      }
+
+    KafkaMicroBatchStream.metrics(latestConsumedOffset, reCalculatedLatestPartitionOffsets)
   }
 
   /**

@@ -386,13 +505,14 @@ object KafkaMicroBatchStream extends Logging {
    */
  def metrics(
      latestConsumedOffset: Optional[Offset],
-     latestAvailablePartitionOffsets: PartitionOffsetMap): ju.Map[String, String] = {
+     latestAvailablePartitionOffsets: Option[PartitionOffsetMap]): ju.Map[String, String] = {
    val offset = Option(latestConsumedOffset.orElse(null))
 
-   if (offset.nonEmpty && latestAvailablePartitionOffsets != null) {
+   if (offset.nonEmpty && latestAvailablePartitionOffsets.isDefined) {
      val consumedPartitionOffsets = offset.map(KafkaSourceOffset(_)).get.partitionToOffsets
-     val offsetsBehindLatest = latestAvailablePartitionOffsets
-       .map(partitionOffset => partitionOffset._2 - consumedPartitionOffsets(partitionOffset._1))
+     val offsetsBehindLatest = latestAvailablePartitionOffsets.get
+       .map(partitionOffset => partitionOffset._2 -
+         consumedPartitionOffsets.getOrElse(partitionOffset._1, 0L))
      if (offsetsBehindLatest.nonEmpty) {
        val avgOffsetBehindLatest = offsetsBehindLatest.sum.toDouble / offsetsBehindLatest.size
        return Map[String, String](
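
Note on the metrics change: the latest-offsets argument is now an Option, and partitions missing from the consumed offset map (for example, partitions added mid-batch) default to a consumed offset of 0 instead of throwing. A self-contained sketch of the resulting backlog arithmetic, with made-up topic names and offsets (the real code keys these maps by TopicPartition):

  // Illustrative values only.
  val latestAvailable: Option[Map[String, Long]] =
    Some(Map("t-0" -> 120L, "t-1" -> 80L, "t-2" -> 10L))        // t-2 is a newly added partition
  val consumed: Map[String, Long] = Map("t-0" -> 100L, "t-1" -> 80L)

  val offsetsBehindLatest = latestAvailable.get.map {
    case (tp, latest) => latest - consumed.getOrElse(tp, 0L)    // 20, 0, 10 (no exception for t-2)
  }
  val avgOffsetBehindLatest =
    offsetsBehindLatest.sum.toDouble / offsetsBehindLatest.size // (20 + 0 + 10) / 3 = 10.0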

connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala

Lines changed: 97 additions & 1 deletion
@@ -63,6 +63,13 @@ private[kafka010] class InternalKafkaConsumer(
   private[consumer] var kafkaParamsWithSecurity: ju.Map[String, Object] = _
   private val consumer = createConsumer()
 
+  def poll(pollTimeoutMs: Long): ju.List[ConsumerRecord[Array[Byte], Array[Byte]]] = {
+    val p = consumer.poll(Duration.ofMillis(pollTimeoutMs))
+    val r = p.records(topicPartition)
+    logDebug(s"Polled $groupId ${p.partitions()} ${r.size}")
+    r
+  }
+
   /**
    * Poll messages from Kafka starting from `offset` and returns a pair of "list of consumer record"
    * and "offset after poll". The list of consumer record may be empty if the Kafka consumer fetches

@@ -131,7 +138,7 @@ private[kafka010] class InternalKafkaConsumer(
     c
   }
 
-  private def seek(offset: Long): Unit = {
+  def seek(offset: Long): Unit = {
    logDebug(s"Seeking to $groupId $topicPartition $offset")
    consumer.seek(topicPartition, offset)
  }

@@ -228,6 +235,19 @@ private[consumer] case class FetchedRecord(
   }
 }
 
+/**
+ * This class keeps returning the next records. If no new record is available, it will keep
+ * polling until timeout. It is used by KafkaBatchPartitionReader.nextWithTimeout(), to reduce
+ * seeking overhead in real time mode.
+ */
+private[sql] trait KafkaDataConsumerIterator {
+  /**
+   * Return the next record
+   * @return None if no new record is available after `timeoutMs`.
+   */
+  def nextWithTimeout(timeoutMs: Long): Option[ConsumerRecord[Array[Byte], Array[Byte]]]
+}
+
 /**
  * This class helps caller to read from Kafka leveraging consumer pool as well as fetched data pool.
  * This class throws error when data loss is detected while reading from Kafka.

@@ -272,6 +292,82 @@ private[kafka010] class KafkaDataConsumer(
   // Starting timestamp when the consumer is created.
   private var startTimestampNano: Long = System.nanoTime()
 
+  /**
+   * Get an iterator that can return the next entry. It is used exclusively for real-time
+   * mode.
+   *
+   * It is called by KafkaBatchPartitionReader.nextWithTimeout(). Unlike get(), there is no
+   * out-of-bound check in this function. Since there is no endOffset given, we assume anything
+   * record is valid to return as long as it is at or after `offset`.
+   *
+   * @param startOffsets, the starting positions to read from, inclusive.
+   */
+  def getIterator(offset: Long): KafkaDataConsumerIterator = {
+    new KafkaDataConsumerIterator {
+      private var fetchedRecordList
+        : Option[ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]]] = None
+      private val consumer = getOrRetrieveConsumer()
+      private var firstRecord = true
+      private var _currentOffset: Long = offset - 1
+
+      private def fetchedRecordListHasNext(): Boolean = {
+        fetchedRecordList.map(_.hasNext).getOrElse(false)
+      }
+
+      override def nextWithTimeout(
+          timeoutMs: Long): Option[ConsumerRecord[Array[Byte], Array[Byte]]] = {
+        var timeLeftMs = timeoutMs
+
+        def timeAndDeductFromTimeLeftMs[T](body: => T): Unit = {
+          // To reduce timing the same operator twice, we reuse the timing results for
+          // totalTimeReadNanos and for timeoutMs.
+          val prevTime = totalTimeReadNanos
+          timeNanos {
+            body
+          }
+          timeLeftMs -= (totalTimeReadNanos - prevTime) / 1000000
+        }
+
+        if (firstRecord) {
+          timeAndDeductFromTimeLeftMs {
+            consumer.seek(offset)
+            firstRecord = false
+          }
+        }
+        while (!fetchedRecordListHasNext() && timeLeftMs > 0) {
+          timeAndDeductFromTimeLeftMs {
+            try {
+              val records = consumer.poll(timeLeftMs)
+              numPolls += 1
+              if (!records.isEmpty) {
+                numRecordsPolled += records.size
+                fetchedRecordList = Some(records.listIterator)
+              }
+            } catch {
+              case ex: OffsetOutOfRangeException =>
+                if (_currentOffset != -1) {
+                  throw ex
+                } else {
+                  Thread.sleep(10) // retry until the source partition is populated
+                  assert(offset == 0)
+                  consumer.seek(offset)
+                }
+            }
+          }
+        }
+        if (fetchedRecordListHasNext()) {
+          totalRecordsRead += 1
+          val nextRecord = fetchedRecordList.get.next()
+          assert(nextRecord.offset > _currentOffset, "Kafka offset should be incremental.")
+          _currentOffset = nextRecord.offset
+          Some(nextRecord)
+        } else {
+          None
+        }
+      }
+    }
+  }
+
   /**
    * Get the record for the given offset if available.
    *
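
To make the new consumer-side API concrete, here is a hypothetical caller of getIterator/nextWithTimeout outside KafkaBatchPartitionReader; topicPartition, kafkaParams, startOffset, the 50 ms budget, and handle(...) are assumptions for illustration, and release() refers to the existing pooled-consumer cleanup rather than anything added in this commit.

  // Sketch: drain one Kafka partition with a per-poll time budget, stopping at the first timeout.
  val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams)
  try {
    val iterator = consumer.getIterator(startOffset)
    var timedOut = false
    while (!timedOut) {
      iterator.nextWithTimeout(50L) match {
        case Some(record) =>
          handle(record)      // offsets returned by the iterator increase monotonically
        case None =>
          timedOut = true     // no record arrived within the 50 ms budget
      }
    }
  } finally {
    consumer.release()        // return the consumer to the pool when done
  }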
