Skip to content
205 changes: 201 additions & 4 deletions core/src/main/scala/kafka/server/AutoTopicCreationManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package kafka.server

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.locks.ReentrantLock
import java.util.{Collections, Properties}
import kafka.coordinator.transaction.TransactionCoordinator
import kafka.utils.Logging
Expand All @@ -35,6 +36,7 @@ import org.apache.kafka.coordinator.share.ShareCoordinator
import org.apache.kafka.coordinator.transaction.TransactionLogConfig
import org.apache.kafka.server.common.{ControllerRequestCompletionHandler, NodeToControllerChannelManager}
import org.apache.kafka.server.quota.ControllerMutationQuota
import org.apache.kafka.common.utils.Time

import scala.collection.{Map, Seq, Set, mutable}
import scala.jdk.CollectionConverters._
Expand All @@ -50,21 +52,111 @@ trait AutoTopicCreationManager {

def createStreamsInternalTopics(
topics: Map[String, CreatableTopic],
requestContext: RequestContext
requestContext: RequestContext,
timeoutMs: Long
): Unit

def getStreamsInternalTopicCreationErrors(
topicNames: Set[String],
currentTimeMs: Long
): Map[String, String]

def close(): Unit = {}

}

/**
* Thread-safe cache that stores topic creation errors with per-entry expiration.
* - Expiration: maintained by a min-heap (priority queue) on expiration time
* - Capacity: enforced by insertion-order removal (keeps the most recently inserted entries)
*/
private[server] class ExpiringErrorCache(maxSize: Int, time: Time) {

private case class Entry(topicName: String, errorMessage: String, expirationTimeMs: Long)

private val byTopic = new ConcurrentHashMap[String, Entry]()
private val expiryQueue = new java.util.PriorityQueue[Entry](11, new java.util.Comparator[Entry] {
override def compare(a: Entry, b: Entry): Int = java.lang.Long.compare(a.expirationTimeMs, b.expirationTimeMs)
})
private val lock = new ReentrantLock()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make byTopic a ConcurrentHashMap and use this lock as a write lock only?
That is, make the read path lock contention free? That would mean we can only expire on the put path, which should be fine. However, then we may read expired entries when getting from the map, so in get you need to check if the returned entry is expired before returning it.


def put(topicName: String, errorMessage: String, ttlMs: Long): Unit = {
lock.lock()
try {
val existing = byTopic.get(topicName)
if (existing != null) {
// Remove old instance from structures
expiryQueue.remove(existing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This remove is a linear time operation, right? I think we should avoid that. I think it may be fine to just leave it in the expiryQueue, since once it expired, we will no deletethe key from the map if the new value was replaced.

}

val currentTimeMs = time.milliseconds()
val expirationTimeMs = currentTimeMs + ttlMs
val entry = Entry(topicName, errorMessage, expirationTimeMs)
byTopic.put(topicName, entry)
expiryQueue.add(entry)

// Clean up expired entries
while (!expiryQueue.isEmpty && expiryQueue.peek().expirationTimeMs <= currentTimeMs) {
val expired = expiryQueue.poll()
val current = byTopic.get(expired.topicName)
if (current != null && (current eq expired)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is eq doing a deep comparison here? Maybe it would be enough to compare the timestamps, the deep comparison is expensive

byTopic.remove(expired.topicName)
}
}

// Enforce capacity by removing entries with earliest expiration time first
while (byTopic.size() > maxSize && !expiryQueue.isEmpty) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you merge this loop into the loop above by just checking the condition

!expiryQueue.isEmpty && (expiryQueue.peek().expirationTimeMs <= currentTimeMs || byTopic.size() > maxSize)

in the while loop?

val evicted = expiryQueue.poll()
if (evicted != null) {
val current = byTopic.get(evicted.topicName)
if (current != null && (current eq evicted)) {
byTopic.remove(evicted.topicName)
}
}
}
} finally {
lock.unlock()
}
}

def getErrorsForTopics(topicNames: Set[String], currentTimeMs: Long): Map[String, String] = {
val result = mutable.Map.empty[String, String]
topicNames.foreach { topicName =>
val entry = byTopic.get(topicName)
if (entry != null && entry.expirationTimeMs > currentTimeMs) {
result.put(topicName, entry.errorMessage)
}
}
result.toMap
}

private[server] def clear(): Unit = {
lock.lock()
try {
byTopic.clear()
expiryQueue.clear()
} finally {
lock.unlock()
}
}
}


class DefaultAutoTopicCreationManager(
config: KafkaConfig,
channelManager: NodeToControllerChannelManager,
groupCoordinator: GroupCoordinator,
txnCoordinator: TransactionCoordinator,
shareCoordinator: ShareCoordinator
shareCoordinator: ShareCoordinator,
time: Time,
topicErrorCacheCapacity: Int = 1000
) extends AutoTopicCreationManager with Logging {

private val inflightTopics = Collections.newSetFromMap(new ConcurrentHashMap[String, java.lang.Boolean]())

// Hardcoded default capacity; can be overridden in tests via constructor param
private val topicCreationErrorCache = new ExpiringErrorCache(topicErrorCacheCapacity, time)

/**
* Initiate auto topic creation for the given topics.
*
Expand Down Expand Up @@ -93,13 +185,21 @@ class DefaultAutoTopicCreationManager(

override def createStreamsInternalTopics(
topics: Map[String, CreatableTopic],
requestContext: RequestContext
requestContext: RequestContext,
timeoutMs: Long
): Unit = {
if (topics.nonEmpty) {
sendCreateTopicRequest(topics, Some(requestContext))
sendCreateTopicRequestWithErrorCaching(topics, Some(requestContext), timeoutMs)
}
}

override def getStreamsInternalTopicCreationErrors(
topicNames: Set[String],
currentTimeMs: Long
): Map[String, String] = {
topicCreationErrorCache.getErrorsForTopics(topicNames, currentTimeMs)
}

private def sendCreateTopicRequest(
creatableTopics: Map[String, CreatableTopic],
requestContext: Option[RequestContext]
Expand Down Expand Up @@ -264,4 +364,101 @@ class DefaultAutoTopicCreationManager(

(creatableTopics, uncreatableTopics)
}

private def sendCreateTopicRequestWithErrorCaching(
creatableTopics: Map[String, CreatableTopic],
requestContext: Option[RequestContext],
timeoutMs: Long
): Seq[MetadataResponseTopic] = {
val topicsToCreate = new CreateTopicsRequestData.CreatableTopicCollection(creatableTopics.size)
topicsToCreate.addAll(creatableTopics.values.asJavaCollection)

val createTopicsRequest = new CreateTopicsRequest.Builder(
new CreateTopicsRequestData()
.setTimeoutMs(config.requestTimeoutMs)
.setTopics(topicsToCreate)
)

val requestCompletionHandler = new ControllerRequestCompletionHandler {
override def onTimeout(): Unit = {
clearInflightRequests(creatableTopics)
debug(s"Auto topic creation timed out for ${creatableTopics.keys}.")
cacheTopicCreationErrors(creatableTopics.keys.toSet, "Auto topic creation timed out.", timeoutMs)
}

override def onComplete(response: ClientResponse): Unit = {
clearInflightRequests(creatableTopics)
if (response.authenticationException() != null) {
val authException = response.authenticationException()
warn(s"Auto topic creation failed for ${creatableTopics.keys} with authentication exception: ${authException.getMessage}")
cacheTopicCreationErrors(creatableTopics.keys.toSet, authException.getMessage, timeoutMs)
} else if (response.versionMismatch() != null) {
val versionException = response.versionMismatch()
warn(s"Auto topic creation failed for ${creatableTopics.keys} with version mismatch exception: ${versionException.getMessage}")
cacheTopicCreationErrors(creatableTopics.keys.toSet, versionException.getMessage, timeoutMs)
} else {
response.responseBody() match {
case createTopicsResponse: CreateTopicsResponse =>
cacheTopicCreationErrorsFromResponse(createTopicsResponse, timeoutMs)
case _ =>
debug(s"Auto topic creation completed for ${creatableTopics.keys} with response ${response.responseBody}.")
}
}
}
}

val request = requestContext.map { context =>
val requestVersion =
channelManager.controllerApiVersions.toScala match {
case None =>
// We will rely on the Metadata request to be retried in the case
// that the latest version is not usable by the controller.
ApiKeys.CREATE_TOPICS.latestVersion()
case Some(nodeApiVersions) =>
nodeApiVersions.latestUsableVersion(ApiKeys.CREATE_TOPICS)
}

// Borrow client information such as client id and correlation id from the original request,
// in order to correlate the create request with the original metadata request.
val requestHeader = new RequestHeader(ApiKeys.CREATE_TOPICS,
requestVersion,
context.clientId,
context.correlationId)
ForwardingManager.buildEnvelopeRequest(context,
createTopicsRequest.build(requestVersion).serializeWithHeader(requestHeader))
}.getOrElse(createTopicsRequest)

channelManager.sendRequest(request, requestCompletionHandler)

val creatableTopicResponses = creatableTopics.keySet.toSeq.map { topic =>
new MetadataResponseTopic()
.setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code)
.setName(topic)
.setIsInternal(Topic.isInternal(topic))
}

creatableTopicResponses
}

private def cacheTopicCreationErrors(topicNames: Set[String], errorMessage: String, ttlMs: Long): Unit = {
topicNames.foreach { topicName =>
topicCreationErrorCache.put(topicName, errorMessage, ttlMs)
}
}

private def cacheTopicCreationErrorsFromResponse(response: CreateTopicsResponse, ttlMs: Long): Unit = {
response.data().topics().forEach { topicResult =>
if (topicResult.errorCode() != Errors.NONE.code()) {
val errorMessage = Option(topicResult.errorMessage())
.filter(_.nonEmpty)
.getOrElse(Errors.forCode(topicResult.errorCode()).message())
topicCreationErrorCache.put(topicResult.name(), errorMessage, ttlMs)
debug(s"Cached topic creation error for ${topicResult.name()}: $errorMessage")
}
}
}

override def close(): Unit = {
topicCreationErrorCache.clear()
}
}
5 changes: 4 additions & 1 deletion core/src/main/scala/kafka/server/BrokerServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ class BrokerServer(

autoTopicCreationManager = new DefaultAutoTopicCreationManager(
config, clientToControllerChannelManager, groupCoordinator,
transactionCoordinator, shareCoordinator)
transactionCoordinator, shareCoordinator, time)

dynamicConfigHandlers = Map[ConfigType, ConfigHandler](
ConfigType.TOPIC -> new TopicConfigHandler(replicaManager, config, quotaManagers),
Expand Down Expand Up @@ -781,6 +781,9 @@ class BrokerServer(
if (shareCoordinator != null)
CoreUtils.swallow(shareCoordinator.shutdown(), this)

if (autoTopicCreationManager != null)
CoreUtils.swallow(autoTopicCreationManager.close(), this)

if (assignmentsManager != null)
CoreUtils.swallow(assignmentsManager.close(), this)

Expand Down
29 changes: 27 additions & 2 deletions core/src/main/scala/kafka/server/KafkaApis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2888,10 +2888,35 @@ class KafkaApis(val requestChannel: RequestChannel,
)
}
} else {
autoTopicCreationManager.createStreamsInternalTopics(topicsToCreate, requestContext);
// Compute group-specific timeout for caching errors (2 * heartbeat interval)
val heartbeatIntervalMs = Option(groupConfigManager.groupConfig(streamsGroupHeartbeatRequest.data.groupId).orElse(null))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TTL is calculated as 2 × heartbeat interval, but the PR description mentions 3 × request.timeout.ms. This inconsistency could be confusing. Can you please fix the PR description? I would keep it much shorter and less AI generated so that it is easier to keep up-to-date.

.map(_.streamsHeartbeatIntervalMs().toLong)
.getOrElse(config.groupCoordinatorConfig.streamsGroupHeartbeatIntervalMs().toLong)
val timeoutMs = heartbeatIntervalMs * 2

autoTopicCreationManager.createStreamsInternalTopics(topicsToCreate, requestContext, timeoutMs)

// Check for cached topic creation errors only if there's already a MISSING_INTERNAL_TOPICS status
val hasMissingInternalTopicsStatus = responseData.status() != null &&
responseData.status().stream().anyMatch(s => s.statusCode() == StreamsGroupHeartbeatResponse.Status.MISSING_INTERNAL_TOPICS.code())

if (hasMissingInternalTopicsStatus) {
val currentTimeMs = time.milliseconds()
val cachedErrors = autoTopicCreationManager.getStreamsInternalTopicCreationErrors(topicsToCreate.keys.toSet, currentTimeMs)
if (cachedErrors.nonEmpty) {
val missingInternalTopicStatus =
responseData.status().stream().filter(x => x.statusCode() == StreamsGroupHeartbeatResponse.Status.MISSING_INTERNAL_TOPICS.code()).findFirst()
val creationErrorDetails = cachedErrors.map { case (topic, error) => s"$topic ($error)" }.mkString(", ")
if (missingInternalTopicStatus.isPresent) {
val existingDetail = Option(missingInternalTopicStatus.get().statusDetail()).getOrElse("")
missingInternalTopicStatus.get().setStatusDetail(
existingDetail + s"; Creation failed: $creationErrorDetails."
)
}
}
}
}
}

requestHelper.sendMaybeThrottle(request, new StreamsGroupHeartbeatResponse(responseData))
}
}
Expand Down
Loading
Loading