[SPARK-53294][SS] Enable StateDataSource with state checkpoint v2 (only batchId option) #52047

Open
wants to merge 7 commits into master
@@ -370,7 +370,8 @@ case class StateSourceOptions(
readChangeFeedOptions: Option[ReadChangeFeedOptions],
stateVarName: Option[String],
readRegisteredTimers: Boolean,
flattenCollectionTypes: Boolean) {
flattenCollectionTypes: Boolean,
operatorStateUniqueIds: Option[Array[Array[String]]] = None) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

override def toString: String = {
@@ -567,10 +568,37 @@ object StateSourceOptions extends DataSourceOptions {
}
}

val startBatchId = if (fromSnapshotOptions.isDefined) {
fromSnapshotOptions.get.snapshotStartBatchId
} else if (readChangeFeedOptions.isDefined) {
readChangeFeedOptions.get.changeStartBatchId
} else {
batchId.get
}

val operatorStateUniqueIds = getOperatorStateUniqueIds(
sparkSession,
startBatchId,
operatorId,
resolvedCpLocation)

if (operatorStateUniqueIds.isDefined) {
if (fromSnapshotOptions.isDefined) {
throw StateDataSourceErrors.invalidOptionValue(
SNAPSHOT_START_BATCH_ID,
"Snapshot reading is currently not supported with checkpoint v2.")
}
if (readChangeFeedOptions.isDefined) {
throw StateDataSourceErrors.invalidOptionValue(
READ_CHANGE_FEED,
"Read change feed is currently not supported with checkpoint v2.")
}
}

StateSourceOptions(
resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
stateVarName, readRegisteredTimers, flattenCollectionTypes)
stateVarName, readRegisteredTimers, flattenCollectionTypes, operatorStateUniqueIds)
}

private def resolvedCheckpointLocation(
@@ -589,6 +617,20 @@ object StateSourceOptions extends DataSourceOptions {
}
}

private def getOperatorStateUniqueIds(
session: SparkSession,
batchId: Long,
operatorId: Long,
checkpointLocation: String): Option[Array[Array[String]]] = {
val commitLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).commitLog
val commitMetadata = commitLog.get(batchId) match {
case Some(commitMetadata) => commitMetadata
case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation)
}

commitMetadata.stateUniqueIds.flatMap(_.get(operatorId))
}
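
For context, a minimal usage sketch (not part of the diff) of what this change enables, with a hypothetical checkpoint path and batch number: under state checkpoint v2 the state data source can be read by batch via the batchId option, while snapshotStartBatchId and readChangeFeed fail with STDS_INVALID_OPTION_VALUE as enforced above.

// Hedged sketch; path and batch number are placeholders.
val stateDf = spark.read
  .format("statestore")
  .option("batchId", 5)          // a committed batch of the streaming query
  .option("operatorId", 0)
  .load("/path/to/checkpoint")   // hypothetical checkpoint written with checkpoint v2
stateDf.show()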

// Modifies options due to external data. Returns modified options.
// If this is a join operator specifying a store name using state format v3,
// we need to modify the options.
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
@@ -95,6 +96,13 @@ abstract class StatePartitionReaderBase(
schema, "value").asInstanceOf[StructType]
}

protected val getStoreUniqueId: Option[String] = {
SymmetricHashJoinStateManager.getStateStoreCheckpointId(
storeName = partition.sourceOptions.storeName,
partitionId = partition.partition,
stateStoreCkptIds = partition.sourceOptions.operatorStateUniqueIds)
}

protected lazy val provider: StateStoreProvider = {
val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
@@ -113,7 +121,9 @@ abstract class StatePartitionReaderBase(
val isInternal = partition.sourceOptions.readRegisteredTimers

if (useColFamilies) {
val store = provider.getStore(partition.sourceOptions.batchId + 1)
val store = provider.getStore(
partition.sourceOptions.batchId + 1,
getStoreUniqueId)
require(stateStoreColFamilySchemaOpt.isDefined)
val stateStoreColFamilySchema = stateStoreColFamilySchemaOpt.get
require(stateStoreColFamilySchema.keyStateEncoderSpec.isDefined)
@@ -171,7 +181,11 @@ class StatePartitionReader(

private lazy val store: ReadStateStore = {
partition.sourceOptions.fromSnapshotOptions match {
case None => provider.getReadStore(partition.sourceOptions.batchId + 1)
case None =>
provider.getReadStore(
partition.sourceOptions.batchId + 1,
getStoreUniqueId
)

case Some(fromSnapshotOptions) =>
if (!provider.isInstanceOf[SupportsFineGrainedReplay]) {
@@ -71,6 +71,28 @@ class StreamStreamJoinStatePartitionReader(
throw StateDataSourceErrors.internalError("Unexpected join side for stream-stream read!")
}

private val usesVirtualColumnFamilies = StreamStreamJoinStateHelper.usesVirtualColumnFamilies(
hadoopConf.value,
partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId)

private val stateStoreCheckpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
partition.partition,
partition.sourceOptions.operatorStateUniqueIds,
usesVirtualColumnFamilies)

private val keyToNumValuesStateStoreCkptId = if (joinSide == LeftSide) {
stateStoreCheckpointIds.left.keyToNumValues
} else {
stateStoreCheckpointIds.right.keyToNumValues
}

private val keyWithIndexToValueStateStoreCkptId = if (joinSide == LeftSide) {
stateStoreCheckpointIds.left.keyWithIndexToValue
} else {
stateStoreCheckpointIds.right.keyWithIndexToValue
}

/*
* This is to handle the difference of schema across state format versions. The major difference
* is whether we have added new field(s) in addition to the fields from input schema.
@@ -85,10 +107,7 @@
// column from the value schema to get the actual fields.
if (maybeMatchedColumn.name == "matched" && maybeMatchedColumn.dataType == BooleanType) {
// If checkpoint is using one store and virtual column families, version is 3
if (StreamStreamJoinStateHelper.usesVirtualColumnFamilies(
hadoopConf.value,
partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId)) {
if (usesVirtualColumnFamilies) {
(valueSchema.dropRight(1), 3)
} else {
(valueSchema.dropRight(1), 2)
@@ -130,8 +149,8 @@ class StreamStreamJoinStatePartitionReader(
storeConf = storeConf,
hadoopConf = hadoopConf.value,
partitionId = partition.partition,
keyToNumValuesStateStoreCkptId = None,
keyWithIndexToValueStateStoreCkptId = None,
keyToNumValuesStateStoreCkptId = keyToNumValuesStateStoreCkptId,
keyWithIndexToValueStateStoreCkptId = keyWithIndexToValueStateStoreCkptId,
formatVersion,
skippedNullValueCount = None,
useStateStoreCoordinator = false,
@@ -351,7 +351,7 @@ case class StreamingSymmetricHashJoinExec(

assert(stateInfo.isDefined, "State info not defined")
val checkpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
partitionId, stateInfo.get, useVirtualColumnFamilies)
partitionId, stateInfo.get.stateStoreCkptIds, useVirtualColumnFamilies)

val inputSchema = left.output ++ right.output
val postJoinFilter =
@@ -363,12 +363,12 @@
new OneSideHashJoiner(
LeftSide, left.output, leftKeys, leftInputIter,
condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left, partitionId,
checkpointIds.left.keyToNumValues, checkpointIds.left.valueToNumKeys,
checkpointIds.left.keyToNumValues, checkpointIds.left.keyWithIndexToValue,
skippedNullValueCount, joinStateManagerStoreGenerator),
new OneSideHashJoiner(
RightSide, right.output, rightKeys, rightInputIter,
condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right, partitionId,
checkpointIds.right.keyToNumValues, checkpointIds.right.valueToNumKeys,
checkpointIds.right.keyToNumValues, checkpointIds.right.keyWithIndexToValue,
skippedNullValueCount, joinStateManagerStoreGenerator))

// Join one side input using the other side's buffered/state rows. Here is how it is done.
@@ -324,15 +324,15 @@ object StreamingSymmetricHashJoinHelper extends Logging {

case class JoinerStateStoreCkptInfo(
keyToNumValues: StateStoreCheckpointInfo,
valueToNumKeys: StateStoreCheckpointInfo)
keyWithIndexToValue: StateStoreCheckpointInfo)

case class JoinStateStoreCkptInfo(
left: JoinerStateStoreCkptInfo,
right: JoinerStateStoreCkptInfo)

case class JoinerStateStoreCheckpointId(
keyToNumValues: Option[String],
valueToNumKeys: Option[String])
keyToNumValues: Option[String],
keyWithIndexToValue: Option[String])

case class JoinStateStoreCheckpointId(
left: JoinerStateStoreCheckpointId,
@@ -1135,17 +1135,17 @@ object SymmetricHashJoinStateManager {
val ckptIds = joinCkptInfo.left.keyToNumValues.stateStoreCkptId.map(
Array(
_,
joinCkptInfo.left.valueToNumKeys.stateStoreCkptId.get,
joinCkptInfo.left.keyWithIndexToValue.stateStoreCkptId.get,
joinCkptInfo.right.keyToNumValues.stateStoreCkptId.get,
joinCkptInfo.right.valueToNumKeys.stateStoreCkptId.get
joinCkptInfo.right.keyWithIndexToValue.stateStoreCkptId.get
)
)
val baseCkptIds = joinCkptInfo.left.keyToNumValues.baseStateStoreCkptId.map(
Array(
_,
joinCkptInfo.left.valueToNumKeys.baseStateStoreCkptId.get,
joinCkptInfo.left.keyWithIndexToValue.baseStateStoreCkptId.get,
joinCkptInfo.right.keyToNumValues.baseStateStoreCkptId.get,
joinCkptInfo.right.valueToNumKeys.baseStateStoreCkptId.get
joinCkptInfo.right.keyWithIndexToValue.baseStateStoreCkptId.get
)
)

@@ -1162,31 +1162,68 @@
* mergeStateStoreCheckpointInfo(). This function is used to read it back into individual state
* store checkpoint IDs.
* @param partitionId
* @param stateInfo
* @param stateStoreCkptIds
* @param useColumnFamiliesForJoins
* @return
*/
def getStateStoreCheckpointIds(
partitionId: Int,
stateInfo: StatefulOperatorStateInfo,
stateStoreCkptIds: Option[Array[Array[String]]],
useColumnFamiliesForJoins: Boolean): JoinStateStoreCheckpointId = {
if (useColumnFamiliesForJoins) {
val ckpt = stateInfo.stateStoreCkptIds.map(_(partitionId)).map(_.head)
val ckpt = stateStoreCkptIds.map(_(partitionId)).map(_.head)
JoinStateStoreCheckpointId(
left = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt),
right = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt)
left = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, keyWithIndexToValue = ckpt),
right = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, keyWithIndexToValue = ckpt)
)
} else {
val stateStoreCkptIds = stateInfo.stateStoreCkptIds
val stateStoreCkptIdsOpt = stateStoreCkptIds
.map(_(partitionId))
.map(_.map(Option(_)))
.getOrElse(Array.fill[Option[String]](4)(None))
JoinStateStoreCheckpointId(
left = JoinerStateStoreCheckpointId(
keyToNumValues = stateStoreCkptIds(0),
valueToNumKeys = stateStoreCkptIds(1)),
keyToNumValues = stateStoreCkptIdsOpt(0),
keyWithIndexToValue = stateStoreCkptIdsOpt(1)),
right = JoinerStateStoreCheckpointId(
keyToNumValues = stateStoreCkptIds(2),
valueToNumKeys = stateStoreCkptIds(3)))
keyToNumValues = stateStoreCkptIdsOpt(2),
keyWithIndexToValue = stateStoreCkptIdsOpt(3)))
}
}

/**
* Stream-stream join has 4 state stores instead of one. So it will generate 4 different
* checkpoint IDs when not using virtual column families.
* This function is used to get the checkpoint ID for a specific state store
* by the name of the store, partition ID and the checkpoint IDs array.
* @param storeName
* @param partitionId
* @param stateStoreCkptIds
* @param useColumnFamiliesForJoins
* @return
*/
def getStateStoreCheckpointId(
storeName: String,
partitionId: Int,
stateStoreCkptIds: Option[Array[Array[String]]],
useColumnFamiliesForJoins: Boolean = false): Option[String] = {
if (useColumnFamiliesForJoins || storeName == StateStoreId.DEFAULT_STORE_NAME) {
stateStoreCkptIds.map(_(partitionId)).map(_.head)
} else {
val joinStateStoreCkptIds = getStateStoreCheckpointIds(
partitionId, stateStoreCkptIds, useColumnFamiliesForJoins)

if (storeName == getStateStoreName(LeftSide, KeyToNumValuesType)) {
joinStateStoreCkptIds.left.keyToNumValues
} else if (storeName == getStateStoreName(RightSide, KeyToNumValuesType)) {
joinStateStoreCkptIds.right.keyToNumValues
} else if (storeName == getStateStoreName(LeftSide, KeyWithIndexToValueType)) {
joinStateStoreCkptIds.left.keyWithIndexToValue
} else if (storeName == getStateStoreName(RightSide, KeyWithIndexToValueType)) {
joinStateStoreCkptIds.right.keyWithIndexToValue
} else {
None
}
}
}
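
A hedged usage sketch (not part of the diff) of the new helper: without virtual column families, each partition's entry carries four IDs in the order left.keyToNumValues, left.keyWithIndexToValue, right.keyToNumValues, right.keyWithIndexToValue, and the helper selects the one matching the requested store name. The store name string below is an assumption mirroring getStateStoreName(LeftSide, KeyWithIndexToValueType).

// Hedged sketch: resolve the checkpoint ID of the left side's keyWithIndexToValue
// store for one partition, given the IDs read back from the commit log.
def leftKeyWithIndexToValueCkptId(
    partitionId: Int,
    operatorStateUniqueIds: Option[Array[Array[String]]]): Option[String] = {
  SymmetricHashJoinStateManager.getStateStoreCheckpointId(
    storeName = "left-keyWithIndexToValue",   // assumed store name string
    partitionId = partitionId,
    stateStoreCkptIds = operatorStateUniqueIds)
}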

@@ -617,6 +617,66 @@ StateDataSourceReadSuite {
}
}

class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceReadSuite {
Contributor

Where do we add the operator-specific tests?

  • aggregations
  • dedup
  • join
  • transformWithState, etc.?

Contributor Author

Extending StateDataSourceReadSuite will run all the tests in StateDataSourceReadSuite with the config set in beforeAll. This pattern already exists in this suite (example). I also attached the result of running this command to the PR description:

testOnly *RocksDBWithCheckpointV2StateDataSourceReaderSuite

Contributor Author

In the future when changeStartBatchId is added we will have to generate the golden files for those tests.

Contributor

Should you extend RocksDBStateDataSourceReadSuite here instead?

Contributor Author

Extending RocksDBStateDataSourceReadSuite would only add one test, which checks for invalid options. It also sets:

spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
      "false")

I would prefer extending StateDataSourceReadSuite for now with plans to eventually extend RocksDBWithChangelogCheckpointStateDataSourceReaderSuite to include the tests with changeStartBatchId.

I added the following to the new test suite:

spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
      "true")

override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
new RocksDBStateStoreProvider

import testImplicits._

override def beforeAll(): Unit = {
super.beforeAll()
spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2)
spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
newStateStoreProvider().getClass.getName)
spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
"true")
}

test("check unsupported modes with checkpoint v2") {
withTempDir { tmpDir =>
val inputData = MemoryStream[(Int, Long)]
val query = getStreamStreamJoinQuery(inputData)
testStream(query)(
StartStream(checkpointLocation = tmpDir.getCanonicalPath),
AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
ProcessAllAvailable(),
Execute { _ => Thread.sleep(2000) },
StopStream
)

// Verify reading snapshot throws error with checkpoint v2
val exc1 = intercept[StateDataSourceInvalidOptionValue] {
val stateSnapshotDf = spark.read.format("statestore")
.option("snapshotPartitionId", 2)
.option("snapshotStartBatchId", 0)
.option("joinSide", "left")
.load(tmpDir.getCanonicalPath)
stateSnapshotDf.collect()
}

checkError(exc1, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", "42616",
Map(
"optionName" -> StateSourceOptions.SNAPSHOT_START_BATCH_ID,
"message" -> "Snapshot reading is currently not supported with checkpoint v2."))

// Verify reading change feed throws error with checkpoint v2
val exc2 = intercept[StateDataSourceInvalidOptionValue] {
val stateDf = spark.read.format("statestore")
.option(StateSourceOptions.READ_CHANGE_FEED, value = true)
.option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
.option(StateSourceOptions.CHANGE_END_BATCH_ID, 1)
.load(tmpDir.getAbsolutePath)
stateDf.collect()
}

checkError(exc2, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", "42616",
Map(
"optionName" -> StateSourceOptions.READ_CHANGE_FEED,
"message" -> "Read change feed is currently not supported with checkpoint v2."))
}
}
Contributor

Can we add a test that creates the checkpoint with checkpoint v2 and tries to read from it with the config set to checkpoint v1, and vice versa?

Contributor Author

Currently, this will error out when reading the commit log.

Should we do something like only checking this assertion if we are calling from the context of a streaming query? Or maybe just check for this assertion error?

cc @anishshri-db

Contributor

Not sure I understand this. The reader should figure out the checkpoint format version it needs to use based on the commit log version, right? Which other config are we referring to here?

Contributor Author

It currently just trims the version number from the input and then verifies it is equal to the expected version according to the sparkSession.conf "spark.sql.streaming.stateStore.checkpointFormatVersion".

There are a few different assertions in the methods we call that require the "spark.sql.streaming.stateStore.checkpointFormatVersion" to be the same version as in the CommitLog.

I see a few options:

  1. Remove these assertions completely
  2. Only check the assertion if we are calling from a streaming query
  3. Require the user to have "spark.sql.streaming.stateStore.checkpointFormatVersion" set to the correct version while using the state reader

What do you think?
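
For reference, a hedged sketch of what option 3 would look like from the caller's side (the conf key is the one quoted above; the checkpoint path is hypothetical):

// Option 3 sketch: the caller pins the checkpoint format version to match the
// checkpoint that was written, before using the state reader.
spark.conf.set("spark.sql.streaming.stateStore.checkpointFormatVersion", 2)
val stateDf = spark.read.format("statestore")
  .option("batchId", 0)
  .load("/path/to/v2/checkpoint")   // hypothetical checkpoint written with checkpoint v2
stateDf.show()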

}

abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Assertions {

import testImplicits._