@@ -370,7 +370,8 @@ case class StateSourceOptions(
readChangeFeedOptions: Option[ReadChangeFeedOptions],
stateVarName: Option[String],
readRegisteredTimers: Boolean,
flattenCollectionTypes: Boolean) {
flattenCollectionTypes: Boolean,
operatorStateUniqueIds: Option[Array[Array[String]]] = None) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

override def toString: String = {

@@ -567,10 +568,38 @@ object StateSourceOptions extends DataSourceOptions {
}
}


val startBatchId = if (fromSnapshotOptions.isDefined) {
fromSnapshotOptions.get.snapshotStartBatchId
} else if (readChangeFeedOptions.isDefined) {
readChangeFeedOptions.get.changeStartBatchId
} else {
batchId.get
}

val operatorStateUniqueIds = getOperatorStateUniqueIds(
sparkSession,
startBatchId,
operatorId,
resolvedCpLocation)

if (operatorStateUniqueIds.isDefined) {
if (fromSnapshotOptions.isDefined) {
throw StateDataSourceErrors.invalidOptionValue(
SNAPSHOT_START_BATCH_ID,
"Snapshot reading is currently not supported with checkpoint v2.")
}
if (readChangeFeedOptions.isDefined) {
throw StateDataSourceErrors.invalidOptionValue(
READ_CHANGE_FEED,
"Read change feed is currently not supported with checkpoint v2.")
}
}

StateSourceOptions(
resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
stateVarName, readRegisteredTimers, flattenCollectionTypes)
stateVarName, readRegisteredTimers, flattenCollectionTypes, operatorStateUniqueIds)
}

private def resolvedCheckpointLocation(

@@ -589,6 +618,26 @@ object StateSourceOptions extends DataSourceOptions {
}
}

private def getOperatorStateUniqueIds(
session: SparkSession,
batchId: Long,
operatorId: Long,
checkpointLocation: String): Option[Array[Array[String]]] = {
val commitLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).commitLog
val commitMetadata = commitLog.get(batchId) match {
case Some(commitMetadata) => commitMetadata
case None => throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation)
}

val operatorStateUniqueIds = if (commitMetadata.stateUniqueIds.isDefined) {
Some(commitMetadata.stateUniqueIds.get(operatorId))

Contributor: nit: This can be written in a more Scala way, without if-else. Maybe with stateUniqueIds.map

Author: Refactored to commitMetadata.stateUniqueIds.flatMap(_.get(operatorId))

} else {
None
}

operatorStateUniqueIds
}
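
For reference, a minimal sketch of the refactor agreed on in the thread above, keeping the signature from this diff. Treat it as an illustration only; the exact type of commitMetadata.stateUniqueIds (an Option over a map keyed by operator id) is assumed from the author's reply.

```scala
// Sketch only: collapses the if/else into a single flatMap, per the review suggestion.
private def getOperatorStateUniqueIds(
    session: SparkSession,
    batchId: Long,
    operatorId: Long,
    checkpointLocation: String): Option[Array[Array[String]]] = {
  val commitLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).commitLog
  val commitMetadata = commitLog.get(batchId).getOrElse(
    throw StateDataSourceErrors.committedBatchUnavailable(checkpointLocation))
  // None when stateUniqueIds is absent or has no entry for this operator id.
  commitMetadata.stateUniqueIds.flatMap(_.get(operatorId))
}
```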

// Modifies options due to external data. Returns modified options.
// If this is a join operator specifying a store name using state format v3,
// we need to modify the options.

@@ -21,13 +21,24 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
import org.apache.spark.sql.types.{NullType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{NextIterator, SerializableConfiguration}

/**
* Constants for store names used in Stream-Stream joins.
*/
object StatePartitionReaderStoreNames {

Contributor: Why define these names here? These are join specific and shouldn't live here. I think they should already be defined in the join code.

Author: I removed these and refactored a bit. I added more detail in the other comment in this file.

val LEFT_KEY_TO_NUM_VALUES_STORE = "left-keyToNumValues"
val LEFT_KEY_WITH_INDEX_TO_VALUE_STORE = "left-keyWithIndexToValue"
val RIGHT_KEY_TO_NUM_VALUES_STORE = "right-keyToNumValues"
val RIGHT_KEY_WITH_INDEX_TO_VALUE_STORE = "right-keyWithIndexToValue"
}

/**
* An implementation of [[PartitionReaderFactory]] for State data source. This is used to support
* general read from a state store instance, rather than specific to the operator.

@@ -95,6 +106,31 @@ abstract class StatePartitionReaderBase(
schema, "value").asInstanceOf[StructType]
}

protected val getStoreUniqueId : Option[String] = {
val partitionStateUniqueIds =
partition.sourceOptions.operatorStateUniqueIds.map(_(partition.partition))
if (partition.sourceOptions.storeName == StateStoreId.DEFAULT_STORE_NAME) {
partitionStateUniqueIds.map(_.head)
} else {
val stateStoreCheckpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
partition.partition,
partition.sourceOptions.operatorStateUniqueIds,
useColumnFamiliesForJoins = false)

partition.sourceOptions.storeName match {

Contributor: You don't need to do this here. This can be done within the join call above, and it will just return the id you need for the storeName instead of returning the entire stateStoreCheckpointIds.

Author: I made a new method getStateStoreCheckpointId in SymmetricHashJoinStateManager which maps the storeName to the correct checkpoint id in one function call. Let me know if this makes more sense.

case StatePartitionReaderStoreNames.LEFT_KEY_TO_NUM_VALUES_STORE =>
stateStoreCheckpointIds.left.keyToNumValues
case StatePartitionReaderStoreNames.LEFT_KEY_WITH_INDEX_TO_VALUE_STORE =>
stateStoreCheckpointIds.left.valueToNumKeys
case StatePartitionReaderStoreNames.RIGHT_KEY_TO_NUM_VALUES_STORE =>
stateStoreCheckpointIds.right.keyToNumValues
case StatePartitionReaderStoreNames.RIGHT_KEY_WITH_INDEX_TO_VALUE_STORE =>
stateStoreCheckpointIds.right.valueToNumKeys
case _ => None
}
}
}
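
For illustration, a rough sketch of the single-call helper the author describes in the thread above. The method name comes from the reply; its signature and placement are assumptions, while the store-name strings and checkpoint-id field names are taken from this diff.

```scala
// Hypothetical helper on SymmetricHashJoinStateManager: resolve the checkpoint id
// for one join store name directly, instead of returning all four ids to the caller.
def getStateStoreCheckpointId(
    storeName: String,
    partitionId: Int,
    stateStoreCkptIds: Option[Array[Array[String]]],
    useColumnFamiliesForJoins: Boolean): Option[String] = {
  val ids = getStateStoreCheckpointIds(partitionId, stateStoreCkptIds, useColumnFamiliesForJoins)
  storeName match {
    case "left-keyToNumValues"       => ids.left.keyToNumValues
    case "left-keyWithIndexToValue"  => ids.left.valueToNumKeys
    case "right-keyToNumValues"      => ids.right.keyToNumValues
    case "right-keyWithIndexToValue" => ids.right.valueToNumKeys
    case _                           => None
  }
}
```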

protected lazy val provider: StateStoreProvider = {
val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)

@@ -113,7 +149,9 @@ abstract class StatePartitionReaderBase(
val isInternal = partition.sourceOptions.readRegisteredTimers

if (useColFamilies) {
val store = provider.getStore(partition.sourceOptions.batchId + 1)
val store = provider.getStore(
partition.sourceOptions.batchId + 1,
getStoreUniqueId)
require(stateStoreColFamilySchemaOpt.isDefined)
val stateStoreColFamilySchema = stateStoreColFamilySchemaOpt.get
require(stateStoreColFamilySchema.keyStateEncoderSpec.isDefined)

@@ -171,7 +209,11 @@ class StatePartitionReader(

private lazy val store: ReadStateStore = {
partition.sourceOptions.fromSnapshotOptions match {
case None => provider.getReadStore(partition.sourceOptions.batchId + 1)
case None =>
provider.getReadStore(
partition.sourceOptions.batchId + 1,
getStoreUniqueId
)

case Some(fromSnapshotOptions) =>
if (!provider.isInstanceOf[SupportsFineGrainedReplay]) {

@@ -71,6 +71,28 @@ class StreamStreamJoinStatePartitionReader(
throw StateDataSourceErrors.internalError("Unexpected join side for stream-stream read!")
}

private val usesVirtualColumnFamilies = StreamStreamJoinStateHelper.usesVirtualColumnFamilies(
hadoopConf.value,
partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId)

private val stateStoreCheckpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
partition.partition,
partition.sourceOptions.operatorStateUniqueIds,
usesVirtualColumnFamilies)

private val keyToNumValuesStateStoreCkptId = if (joinSide == LeftSide) {
stateStoreCheckpointIds.left.keyToNumValues
} else {
stateStoreCheckpointIds.right.keyToNumValues
}

private val keyWithIndexToValueStateStoreCkptId = if (joinSide == LeftSide) {
stateStoreCheckpointIds.left.valueToNumKeys

Contributor: What is valueToNumKeys here?

Author: I think it was a typo in the method I am calling. I refactored by changing valueToNumKeys -> keyWithIndexToValue. This now follows the current store names we use.

} else {
stateStoreCheckpointIds.right.valueToNumKeys
}

/*
* This is to handle the difference of schema across state format versions. The major difference
* is whether we have added new field(s) in addition to the fields from input schema.

@@ -85,10 +107,7 @@ class StreamStreamJoinStatePartitionReader(
// column from the value schema to get the actual fields.
if (maybeMatchedColumn.name == "matched" && maybeMatchedColumn.dataType == BooleanType) {
// If checkpoint is using one store and virtual column families, version is 3
if (StreamStreamJoinStateHelper.usesVirtualColumnFamilies(
hadoopConf.value,
partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId)) {
if (usesVirtualColumnFamilies) {
(valueSchema.dropRight(1), 3)
} else {
(valueSchema.dropRight(1), 2)

@@ -130,8 +149,8 @@ class StreamStreamJoinStatePartitionReader(
storeConf = storeConf,
hadoopConf = hadoopConf.value,
partitionId = partition.partition,
keyToNumValuesStateStoreCkptId = None,
keyWithIndexToValueStateStoreCkptId = None,
keyToNumValuesStateStoreCkptId = keyToNumValuesStateStoreCkptId,
keyWithIndexToValueStateStoreCkptId = keyWithIndexToValueStateStoreCkptId,
formatVersion,
skippedNullValueCount = None,
useStateStoreCoordinator = false,

@@ -351,7 +351,7 @@ case class StreamingSymmetricHashJoinExec(

assert(stateInfo.isDefined, "State info not defined")
val checkpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
partitionId, stateInfo.get, useVirtualColumnFamilies)
partitionId, stateInfo.get.stateStoreCkptIds, useVirtualColumnFamilies)

val inputSchema = left.output ++ right.output
val postJoinFilter =

@@ -1167,26 +1167,26 @@ object SymmetricHashJoinStateManager {
*/
def getStateStoreCheckpointIds(
partitionId: Int,
stateInfo: StatefulOperatorStateInfo,
stateStoreCkptIds: Option[Array[Array[String]]],
useColumnFamiliesForJoins: Boolean): JoinStateStoreCheckpointId = {
if (useColumnFamiliesForJoins) {
val ckpt = stateInfo.stateStoreCkptIds.map(_(partitionId)).map(_.head)
val ckpt = stateStoreCkptIds.map(_(partitionId)).map(_.head)
JoinStateStoreCheckpointId(
left = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt),
right = JoinerStateStoreCheckpointId(keyToNumValues = ckpt, valueToNumKeys = ckpt)
)
} else {
val stateStoreCkptIds = stateInfo.stateStoreCkptIds
val stateStoreCkptIdsOpt = stateStoreCkptIds
.map(_(partitionId))
.map(_.map(Option(_)))
.getOrElse(Array.fill[Option[String]](4)(None))
JoinStateStoreCheckpointId(
left = JoinerStateStoreCheckpointId(
keyToNumValues = stateStoreCkptIds(0),
valueToNumKeys = stateStoreCkptIds(1)),
keyToNumValues = stateStoreCkptIdsOpt(0),
valueToNumKeys = stateStoreCkptIdsOpt(1)),
right = JoinerStateStoreCheckpointId(
keyToNumValues = stateStoreCkptIds(2),
valueToNumKeys = stateStoreCkptIds(3)))
keyToNumValues = stateStoreCkptIdsOpt(2),
valueToNumKeys = stateStoreCkptIdsOpt(3)))
}
}
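
A small usage sketch of the helper above (the ids are made up; the ordering follows the array indexing in the non-column-family branch):

```scala
// Without virtual column families, each partition carries four checkpoint ids, in the order
// [left.keyToNumValues, left.valueToNumKeys, right.keyToNumValues, right.valueToNumKeys].
val ckptIds: Option[Array[Array[String]]] =
  Some(Array(Array("ckpt-l-kv", "ckpt-l-kwiv", "ckpt-r-kv", "ckpt-r-kwiv")))

val ids = SymmetricHashJoinStateManager.getStateStoreCheckpointIds(
  partitionId = 0,
  stateStoreCkptIds = ckptIds,
  useColumnFamiliesForJoins = false)

assert(ids.left.keyToNumValues.contains("ckpt-l-kv"))
assert(ids.right.valueToNumKeys.contains("ckpt-r-kwiv"))
```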

@@ -690,7 +690,7 @@ private[sql] class RocksDBStateStoreProvider

rocksDB.load(
version,
stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None,
stateStoreCkptId = uniqueId,

Contributor: Why remove the conf check?

Author: I thought the behavior was a bit confusing where uniqueId could be Some("") but would not be used to get the underlying store. This also would need to be removed in the future if we wanted to enable reading checkpoint v2 stores when enableStateStoreCheckpointIds = false.

Contributor: Can we do this as a separate change though @dylanwong250?

Author: Sounds good. I removed it. We may have to add it back depending on this comment: https://github.com/apache/spark/pull/52047/files#r2283491163

readOnly = readOnly)

// Create or reuse store instance

@@ -617,6 +617,64 @@ StateDataSourceReadSuite {
}
}

class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceReadSuite {
override protected def newStateStoreProvider(): RocksDBStateStoreProvider =
new RocksDBStateStoreProvider

import testImplicits._

override def beforeAll(): Unit = {
super.beforeAll()
spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2)
spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
newStateStoreProvider().getClass.getName)
}

test("check unsupported modes with checkpoint v2") {
withTempDir { tmpDir =>
val inputData = MemoryStream[(Int, Long)]
val query = getStreamStreamJoinQuery(inputData)
testStream(query)(
StartStream(checkpointLocation = tmpDir.getCanonicalPath),
AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
ProcessAllAvailable(),
Execute { _ => Thread.sleep(2000) },
StopStream
)

// Verify reading snapshot throws error with checkpoint v2
val exc1 = intercept[StateDataSourceInvalidOptionValue] {
val stateSnapshotDf = spark.read.format("statestore")
.option("snapshotPartitionId", 2)
.option("snapshotStartBatchId", 0)
.option("joinSide", "left")
.load(tmpDir.getCanonicalPath)
stateSnapshotDf.collect()
}

checkError(exc1, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", "42616",
Map(
"optionName" -> StateSourceOptions.SNAPSHOT_START_BATCH_ID,
"message" -> "Snapshot reading is currently not supported with checkpoint v2."))

// Verify reading change feed throws error with checkpoint v2
val exc2 = intercept[StateDataSourceInvalidOptionValue] {
val stateDf = spark.read.format("statestore")
.option(StateSourceOptions.READ_CHANGE_FEED, value = true)
.option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
.option(StateSourceOptions.CHANGE_END_BATCH_ID, 1)
.load(tmpDir.getAbsolutePath)
stateDf.collect()
}

checkError(exc2, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", "42616",
Map(
"optionName" -> StateSourceOptions.READ_CHANGE_FEED,
"message" -> "Read change feed is currently not supported with checkpoint v2."))
}
}
}

abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Assertions {

import testImplicits._