
[SPARK-52689][SQL] Send DML Metrics to V2Write #51377

Closed · wants to merge 7 commits
BatchWrite.java
@@ -19,6 +19,8 @@

import org.apache.spark.annotation.Evolving;

import java.util.Map;

/**
* An interface that defines how to write the data to data source for batch processing.
* <p>
@@ -88,6 +90,49 @@ default void onDataWriterCommit(WriterCommitMessage message) {}
*/
void commit(WriterCommitMessage[] messages);

/**
* Commits this writing job with a list of commit messages and operation metrics.
* <p>
* If this method fails (by throwing an exception), this writing job is considered to have
* failed, and {@link #abort(WriterCommitMessage[])} would be called. The state of the destination
* is undefined and {@link #abort(WriterCommitMessage[])} may not be able to deal with it.
* <p>
* Note that speculative execution may cause multiple tasks to run for a partition. By default,
* Spark uses the commit coordinator to allow at most one task to commit. Implementations can
* disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple
* tasks may have committed successfully and one successful commit message per task will be
* passed to this commit method. The remaining commit messages are ignored by Spark.
* <p>
* @param messages a list of commit messages from successful data writers, produced by
* {@link DataWriter#commit()}.
* @param metrics a map of operation metrics collected from the query producing the write.
*                The keys will be prefixed by the operation type, e.g. `merge`.
* <p>
* Currently supported metrics are:
* <ul>
* <li>Operation Type = `merge`
* <ul>
* <li>`numTargetRowsCopied`: number of target rows copied unmodified because
* they did not match any action</li>
* <li>`numTargetRowsDeleted`: number of target rows deleted</li>
* <li>`numTargetRowsUpdated`: number of target rows updated</li>
* <li>`numTargetRowsInserted`: number of target rows inserted</li>
* <li>`numTargetRowsMatchedUpdated`: number of target rows updated by a
* matched clause</li>
* <li>`numTargetRowsMatchedDeleted`: number of target rows deleted by a
* matched clause</li>
* <li>`numTargetRowsNotMatchedBySourceUpdated`: number of target rows
* updated by a not matched by source clause</li>
* <li>`numTargetRowsNotMatchedBySourceDeleted`: number of target rows
* deleted by a not matched by source clause</li>
* </ul>
* </li>
* </ul>
*/
default void commit(WriterCommitMessage[] messages, Map<String, Long> metrics) {
commit(messages);
}

/**
* Aborts this writing job because some data writers are failed and keep failing when retry,
* or the Spark job fails with some unknown reasons,
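For connector authors, here is a minimal sketch of how an implementation could pick up the new overload. The class and field names are illustrative and not part of this PR; only the commit(WriterCommitMessage[], Map<String, Long>) signature comes from the interface above.

import java.{lang, util}

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}

// Hypothetical wrapper that records the DML metrics Spark passes in at commit time.
class MetricsRecordingBatchWrite(delegate: BatchWrite) extends BatchWrite {

  @volatile private var lastOperationMetrics: Map[String, Long] = Map.empty

  override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
    delegate.createBatchWriterFactory(info)

  override def commit(messages: Array[WriterCommitMessage]): Unit =
    delegate.commit(messages)

  // New overload added by this PR: for a MERGE the keys look like
  // "merge.numTargetRowsUpdated"; the default implementation simply forwards
  // to commit(messages), so existing connectors keep working unchanged.
  override def commit(
      messages: Array[WriterCommitMessage],
      metrics: util.Map[String, lang.Long]): Unit = {
    lastOperationMetrics = metrics.asScala.map { case (k, v) => k -> v.longValue() }.toMap
    commit(messages)
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit =
    delegate.abort(messages)
}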
InMemoryBaseTable.scala
@@ -23,6 +23,7 @@ import java.util
import java.util.OptionalLong

import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._

import com.google.common.base.Objects
@@ -152,6 +153,8 @@ abstract class InMemoryBaseTable(
// The key `Seq[Any]` is the partition values, value is a set of splits, each with a set of rows.
val dataMap: mutable.Map[Seq[Any], Seq[BufferedRows]] = mutable.Map.empty

val commits: ListBuffer[Commit] = ListBuffer[Commit]()

def data: Array[BufferedRows] = dataMap.values.flatten.toArray

def rows: Seq[InternalRow] = dataMap.values.flatten.flatMap(_.rows).toSeq
@@ -616,6 +619,9 @@
}

protected abstract class TestBatchWrite extends BatchWrite {

var commitProperties: mutable.Map[String, String] = mutable.Map.empty[String, String]

override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
BufferedRowsWriterFactory
}
@@ -624,8 +630,11 @@
}

class Append(val info: LogicalWriteInfo) extends TestBatchWrite {

override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
withData(messages.map(_.asInstanceOf[BufferedRows]))
commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
commitProperties.clear()
}
}

@@ -634,13 +643,17 @@
val newData = messages.map(_.asInstanceOf[BufferedRows])
dataMap --= newData.flatMap(_.rows.map(getKey))
withData(newData)
commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
commitProperties.clear()
}
}

class TruncateAndAppend(val info: LogicalWriteInfo) extends TestBatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
dataMap.clear()
withData(messages.map(_.asInstanceOf[BufferedRows]))
commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
commitProperties.clear()
}
}

@@ -882,6 +895,8 @@ class InMemoryCustomDriverTaskMetric(value: Long) extends CustomTaskMetric {
override def value(): Long = value
}

case class Commit(id: Long, properties: Map[String, String])

sealed trait Operation
case object Write extends Operation
case object Delete extends Operation
InMemoryRowLevelOperationTable.scala
@@ -17,7 +17,10 @@

package org.apache.spark.sql.connector.catalog

import java.util
import java.{lang, util}
import java.time.Instant

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
@@ -111,7 +114,21 @@ class InMemoryRowLevelOperationTable(
override def description(): String = "InMemoryPartitionReplaceOperation"
}

private case class PartitionBasedReplaceData(scan: InMemoryBatchScan) extends TestBatchWrite {
abstract class RowLevelOperationBatchWrite extends TestBatchWrite {

override def commit(messages: Array[WriterCommitMessage],
metrics: util.Map[String, lang.Long]): Unit = {
metrics.asScala.foreach {
case (key, value) => commitProperties += key -> String.valueOf(value)
}
commit(messages)
commits += Commit(Instant.now().toEpochMilli, commitProperties.toMap)
commitProperties.clear()
}
}

private case class PartitionBasedReplaceData(scan: InMemoryBatchScan)
extends RowLevelOperationBatchWrite {

override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
val newData = messages.map(_.asInstanceOf[BufferedRows])
@@ -165,7 +182,7 @@
}
}

private object TestDeltaBatchWrite extends DeltaBatchWrite {
private object TestDeltaBatchWrite extends RowLevelOperationBatchWrite with DeltaBatchWrite {
override def createBatchWriterFactory(info: PhysicalWriteInfo): DeltaWriterFactory = {
DeltaBufferedRowsWriterFactory
}
WriteToDataSourceV2Exec.scala
@@ -17,6 +17,9 @@

package org.apache.spark.sql.execution.datasources.v2

import java.lang
import java.util

import scala.jdk.CollectionConverters._

import org.apache.spark.{SparkEnv, SparkException, TaskContext}
@@ -34,6 +37,7 @@ import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, PhysicalWriteInfoImpl, Write, WriterCommitMessage}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{LongAccumulator, Utils}
@@ -398,7 +402,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec {
/**
* The base physical plan for writing data into data source v2.
*/
trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
trait V2TableWriteExec extends V2CommandExec with UnaryExecNode with AdaptiveSparkPlanHelper {
def query: SparkPlan
def writingTask: WritingSparkTask[_] = DataWritingSparkTask

@@ -451,8 +455,9 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
}
)

val operationMetrics = getOperationMetrics(query)
logInfo(log"Data source write support ${MDC(LogKeys.BATCH_WRITE, batchWrite)} is committing.")
batchWrite.commit(messages)
batchWrite.commit(messages, operationMetrics)
logInfo(log"Data source write support ${MDC(LogKeys.BATCH_WRITE, batchWrite)} committed.")
commitProgress = Some(StreamWriterCommitProgress(totalNumRowsAccumulator.value))
} catch {
@@ -474,6 +479,12 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {

Nil
}

private def getOperationMetrics(query: SparkPlan): util.Map[String, lang.Long] = {
collectFirst(query) { case m: MergeRowsExec => m }.map { n =>
n.metrics.map { case (name, metric) => s"merge.$name" -> lang.Long.valueOf(metric.value) }
}.getOrElse(Map.empty[String, lang.Long]).asJava
}
}

trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serializable {
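For orientation, the map that getOperationMetrics hands to BatchWrite.commit has the shape sketched below. The values are illustrative and mirror the first set of commit-property assertions in the test suite that follows (group-based MERGE: two rows copied, one matched row updated); only the key naming comes from this PR.

import java.{lang, util}

import scala.jdk.CollectionConverters._

// Illustrative shape of the operation metrics for a group-based MERGE; a
// delta-based plan would report 0 for numTargetRowsCopied.
val operationMetrics: util.Map[String, lang.Long] = Map(
  "merge.numTargetRowsCopied" -> 2L,
  "merge.numTargetRowsInserted" -> 0L,
  "merge.numTargetRowsUpdated" -> 1L,
  "merge.numTargetRowsDeleted" -> 0L,
  "merge.numTargetRowsMatchedUpdated" -> 1L,
  "merge.numTargetRowsMatchedDeleted" -> 0L,
  "merge.numTargetRowsNotMatchedBySourceUpdated" -> 0L,
  "merge.numTargetRowsNotMatchedBySourceDeleted" -> 0L
).map { case (k, v) => k -> lang.Long.valueOf(v) }.asJava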
MergeIntoTableSuiteBase.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Not}
import org.apache.spark.sql.catalyst.optimizer.BuildLeft
import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, TableInfo}
import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, InMemoryTable, TableInfo}
import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -1811,6 +1811,17 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
Row(1, 1000, "hr"), // updated
Row(2, 200, "software"),
Row(3, 300, "hr")))

val table = catalog.loadTable(ident)
val commitProps = table.asInstanceOf[InMemoryTable].commits.last.properties
assert(commitProps("merge.numTargetRowsCopied") === (if (deltaMerge) "0" else "2"))
assert(commitProps("merge.numTargetRowsInserted") === "0")
assert(commitProps("merge.numTargetRowsUpdated") === "1")
assert(commitProps("merge.numTargetRowsDeleted") === "0")
assert(commitProps("merge.numTargetRowsMatchedUpdated") === "1")
assert(commitProps("merge.numTargetRowsMatchedDeleted") === "0")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceUpdated") === "0")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceDeleted") === "0")
}
}

@@ -1856,6 +1867,17 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
Row(2, 200, "software"),
Row(3, 300, "hr"),
Row(5, 400, "executive"))) // inserted

val table = catalog.loadTable(ident)
val commitProps = table.asInstanceOf[InMemoryTable].commits.last.properties
assert(commitProps("merge.numTargetRowsCopied") === "0")
assert(commitProps("merge.numTargetRowsInserted") === "1")
assert(commitProps("merge.numTargetRowsUpdated") === "0")
assert(commitProps("merge.numTargetRowsDeleted") === "0")
assert(commitProps("merge.numTargetRowsMatchedUpdated") === "0")
assert(commitProps("merge.numTargetRowsMatchedDeleted") === "0")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceUpdated") === "0")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceDeleted") === "0")
}
}

@@ -1883,7 +1905,6 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
|""".stripMargin
}


assertMetric(mergeExec, "numTargetRowsCopied", if (deltaMerge) 0 else 3)
assertMetric(mergeExec, "numTargetRowsInserted", 0)
assertMetric(mergeExec, "numTargetRowsUpdated", 2)
@@ -1901,6 +1922,17 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
Row(3, 300, "hr"),
Row(4, 400, "marketing"),
Row(5, -1, "executive"))) // updated

val table = catalog.loadTable(ident)
val commitProps = table.asInstanceOf[InMemoryTable].commits.last.properties
assert(commitProps("merge.numTargetRowsCopied") === (if (deltaMerge) "0" else "3"))
assert(commitProps("merge.numTargetRowsInserted") === "0")
assert(commitProps("merge.numTargetRowsUpdated") === "2")
assert(commitProps("merge.numTargetRowsDeleted") === "0")
assert(commitProps("merge.numTargetRowsMatchedUpdated") === "1")
assert(commitProps("merge.numTargetRowsMatchedDeleted") === "0")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceUpdated") === "1")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceDeleted") === "0")
}
}

@@ -1947,6 +1979,17 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
Row(4, 400, "marketing"))
// Row(5, 500, "executive") deleted
)

val table = catalog.loadTable(ident)
val commitProps = table.asInstanceOf[InMemoryTable].commits.last.properties
assert(commitProps("merge.numTargetRowsCopied") === (if (deltaMerge) "0" else "3"))
assert(commitProps("merge.numTargetRowsInserted") === "0")
assert(commitProps("merge.numTargetRowsUpdated") === "0")
assert(commitProps("merge.numTargetRowsDeleted") === "2")
assert(commitProps("merge.numTargetRowsMatchedUpdated") === "0")
assert(commitProps("merge.numTargetRowsMatchedDeleted") === "1")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceUpdated") === "0")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceDeleted") === "1")
}
}

@@ -1994,6 +2037,17 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
Row(4, 400, "marketing"),
Row(5, -1, "executive"), // updated
Row(6, -1, "dummy"))) // inserted

val table = catalog.loadTable(ident)
val commitProps = table.asInstanceOf[InMemoryTable].commits.last.properties
assert(commitProps("merge.numTargetRowsCopied") === (if (deltaMerge) "0" else "3"))
assert(commitProps("merge.numTargetRowsInserted") === "1")
assert(commitProps("merge.numTargetRowsUpdated") === "2")
assert(commitProps("merge.numTargetRowsDeleted") === "0")
assert(commitProps("merge.numTargetRowsMatchedUpdated") === "1")
assert(commitProps("merge.numTargetRowsMatchedDeleted") === "0")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceUpdated") === "1")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceDeleted") === "0")
}
}

@@ -2032,7 +2086,6 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceUpdated", 0)
assertMetric(mergeExec, "numTargetRowsNotMatchedBySourceDeleted", 1)


checkAnswer(
sql(s"SELECT * FROM $tableNameAsString"),
Seq(
@@ -2042,6 +2095,62 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
Row(4, 400, "marketing"),
// Row(5, 500, "executive") deleted
Row(6, -1, "dummy"))) // inserted

val table = catalog.loadTable(ident)
val commitProps = table.asInstanceOf[InMemoryTable].commits.last.properties
assert(commitProps("merge.numTargetRowsCopied") === (if (deltaMerge) "0" else "3"))
assert(commitProps("merge.numTargetRowsInserted") === "1")
assert(commitProps("merge.numTargetRowsUpdated") === "0")
assert(commitProps("merge.numTargetRowsDeleted") === "2")
assert(commitProps("merge.numTargetRowsMatchedUpdated") === "0")
assert(commitProps("merge.numTargetRowsMatchedDeleted") === "1")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceUpdated") === "0")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceDeleted") === "1")
}
}

test("SPARK-52689: V2 write metrics for merge") {
Seq("true", "false").foreach { aqeEnabled: String =>
withTempView("source") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
"""{ "pk": 1, "salary": 100, "dep": "hr" }
|{ "pk": 2, "salary": 200, "dep": "software" }
|{ "pk": 3, "salary": 300, "dep": "hr" }
|{ "pk": 4, "salary": 400, "dep": "marketing" }
|{ "pk": 5, "salary": 500, "dep": "executive" }
|""".stripMargin)

val sourceDF = Seq(1, 2, 6, 10).toDF("pk")
sourceDF.createOrReplaceTempView("source")

sql(
s"""MERGE INTO $tableNameAsString t
|USING source s
|ON t.pk = s.pk
|WHEN MATCHED AND salary < 200 THEN
| DELETE
|WHEN NOT MATCHED AND s.pk < 10 THEN
| INSERT (pk, salary, dep) VALUES (s.pk, -1, "dummy")
|WHEN NOT MATCHED BY SOURCE AND salary > 400 THEN
| DELETE
|""".stripMargin
)

val table = catalog.loadTable(ident)
val commitProps = table.asInstanceOf[InMemoryTable].commits.last.properties
assert(commitProps("merge.numTargetRowsCopied") === (if (deltaMerge) "0" else "3"))
assert(commitProps("merge.numTargetRowsInserted") === "1")
assert(commitProps("merge.numTargetRowsUpdated") === "0")
assert(commitProps("merge.numTargetRowsDeleted") === "2")
assert(commitProps("merge.numTargetRowsMatchedUpdated") === "0")
assert(commitProps("merge.numTargetRowsMatchedDeleted") === "1")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceUpdated") === "0")
assert(commitProps("merge.numTargetRowsNotMatchedBySourceDeleted") === "1")

sql(s"DROP TABLE $tableNameAsString")
}
}
}
}
