Skip to content

Commit 9f86684

Browse files
committed
[SPARK-53292] Make CreateResourceProfileCommand in SparkConnectPlanner side effect free
1 parent 967f2b6 commit 9f86684

File tree

4 files changed

+119
-24
lines changed

4 files changed

+119
-24
lines changed

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.spark.sql.classic.{DataFrame, Dataset}
3232
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
3333
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
3434
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
35+
import org.apache.spark.sql.connect.execution.command.ConnectLeafRunnableCommand
3536
import org.apache.spark.sql.connect.planner.{InvalidInputErrors, SparkConnectPlanner}
3637
import org.apache.spark.sql.connect.service.ExecuteHolder
3738
import org.apache.spark.sql.connect.utils.MetricGenerator
@@ -90,6 +91,15 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
9091
shuffleCleanupMode = shuffleCleanupMode)
9192
qe.assertCommandExecuted()
9293
executeHolder.eventsManager.postFinished()
94+
qe.logical match {
95+
case connectCommand: ConnectLeafRunnableCommand =>
96+
connectCommand.handleConnectResponse(
97+
responseObserver,
98+
sessionHolder.sessionId,
99+
sessionHolder.serverSessionId)
100+
case _ =>
101+
// Do nothing
102+
}
93103
case None =>
94104
planner.process(command, responseObserver)
95105
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql.connect.execution.command
18+
19+
import io.grpc.stub.StreamObserver
20+
21+
import org.apache.spark.connect.proto
22+
import org.apache.spark.sql.{classic, Row, SparkSession}
23+
import org.apache.spark.sql.execution.command.LeafRunnableCommand
24+
25+
trait ConnectLeafRunnableCommand extends LeafRunnableCommand {
26+
27+
final override def run(sparkSession: SparkSession): Seq[Row] = {
28+
run(sparkSession.asInstanceOf[classic.SparkSession])
29+
}
30+
31+
def run(sparkSession: classic.SparkSession): Seq[Row]
32+
33+
def handleConnectResponse(
34+
responseObserver: StreamObserver[proto.ExecutePlanResponse],
35+
sessionId: String,
36+
serverSessionId: String): Unit = {
37+
// Default implementation does nothing.
38+
}
39+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql.connect.execution.command
18+
19+
import io.grpc.stub.StreamObserver
20+
21+
import org.apache.spark.connect.proto
22+
import org.apache.spark.connect.proto.ExecutePlanResponse
23+
import org.apache.spark.resource.ResourceProfile
24+
import org.apache.spark.sql.Row
25+
import org.apache.spark.sql.classic.SparkSession
26+
27+
case class CreateResourceProfileCommand(rp: ResourceProfile) extends ConnectLeafRunnableCommand {
28+
29+
override def run(sparkSession: SparkSession): Seq[Row] = {
30+
sparkSession.sparkContext.resourceProfileManager.addResourceProfile(rp)
31+
Seq.empty
32+
}
33+
34+
override def handleConnectResponse(
35+
responseObserver: StreamObserver[ExecutePlanResponse],
36+
sessionId: String,
37+
serverSessionId: String): Unit = {
38+
responseObserver.onNext(
39+
proto.ExecutePlanResponse
40+
.newBuilder()
41+
.setSessionId(sessionId)
42+
.setServerSideSessionId(serverSessionId)
43+
.setCreateResourceProfileCommandResult(
44+
proto.CreateResourceProfileCommandResult
45+
.newBuilder()
46+
.setProfileId(rp.id)
47+
.build())
48+
.build())
49+
}
50+
}

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException,
3434
import org.apache.spark.annotation.{DeveloperApi, Since}
3535
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
3636
import org.apache.spark.connect.proto
37-
import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
37+
import org.apache.spark.connect.proto.{CheckpointCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
3838
import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
3939
import org.apache.spark.connect.proto.Parse.ParseFormat
4040
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -62,6 +62,7 @@ import org.apache.spark.sql.classic.ClassicConversions._
6262
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
6363
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
6464
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
65+
import org.apache.spark.sql.connect.execution.command.{ConnectLeafRunnableCommand, CreateResourceProfileCommand}
6566
import org.apache.spark.sql.connect.ml.MLHandler
6667
import org.apache.spark.sql.connect.pipelines.PipelinesHandler
6768
import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
@@ -2654,6 +2655,8 @@ class SparkConnectPlanner(
26542655
Some(transformMergeIntoTableCommand(command.getMergeIntoTableCommand))
26552656
case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
26562657
Some(_ => transformCreateViewCommand(command.getCreateDataframeView))
2658+
case proto.Command.CommandTypeCase.CREATE_RESOURCE_PROFILE_COMMAND =>
2659+
Some(_ => transformCreateResourceProfileCommand(command.getCreateResourceProfileCommand))
26572660
case _ =>
26582661
None
26592662
}
@@ -2664,7 +2667,7 @@ class SparkConnectPlanner(
26642667
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
26652668
val transformerOpt = transformCommand(command)
26662669
if (transformerOpt.isDefined) {
2667-
transformAndRunCommand(transformerOpt.get)
2670+
transformAndRunCommand(transformerOpt.get, responseObserver)
26682671
return
26692672
}
26702673
command.getCommandTypeCase match {
@@ -2693,10 +2696,6 @@ class SparkConnectPlanner(
26932696
responseObserver)
26942697
case proto.Command.CommandTypeCase.GET_RESOURCES_COMMAND =>
26952698
handleGetResourcesCommand(responseObserver)
2696-
case proto.Command.CommandTypeCase.CREATE_RESOURCE_PROFILE_COMMAND =>
2697-
handleCreateResourceProfileCommand(
2698-
command.getCreateResourceProfileCommand,
2699-
responseObserver)
27002699
case proto.Command.CommandTypeCase.CHECKPOINT_COMMAND =>
27012700
handleCheckpointCommand(command.getCheckpointCommand, responseObserver)
27022701
case proto.Command.CommandTypeCase.REMOVE_CACHED_REMOTE_RELATION_COMMAND =>
@@ -3153,11 +3152,22 @@ class SparkConnectPlanner(
31533152
}
31543153
}
31553154

3156-
private def transformAndRunCommand(transformer: QueryPlanningTracker => LogicalPlan): Unit = {
3155+
private def transformAndRunCommand(
3156+
transformer: QueryPlanningTracker => LogicalPlan,
3157+
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
31573158
val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
31583159
val qe = new QueryExecution(session, transformer(tracker), tracker)
31593160
qe.assertCommandExecuted()
31603161
executeHolder.eventsManager.postFinished()
3162+
qe.logical match {
3163+
case connectCommand: ConnectLeafRunnableCommand =>
3164+
connectCommand.handleConnectResponse(
3165+
responseObserver,
3166+
sessionId,
3167+
sessionHolder.serverSessionId)
3168+
case _ =>
3169+
// Do nothing
3170+
}
31613171
}
31623172

31633173
/**
@@ -3674,9 +3684,8 @@ class SparkConnectPlanner(
36743684
.build())
36753685
}
36763686

3677-
private def handleCreateResourceProfileCommand(
3678-
createResourceProfileCommand: CreateResourceProfileCommand,
3679-
responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
3687+
private def transformCreateResourceProfileCommand(
3688+
createResourceProfileCommand: proto.CreateResourceProfileCommand): LogicalPlan = {
36803689
val rp = createResourceProfileCommand.getProfile
36813690
val ereqs = rp.getExecutorResourcesMap.asScala.map { case (name, res) =>
36823691
name -> new ExecutorResourceRequest(
@@ -3695,20 +3704,7 @@ class SparkConnectPlanner(
36953704
} else {
36963705
new ResourceProfile(ereqs, treqs)
36973706
}
3698-
session.sparkContext.resourceProfileManager.addResourceProfile(profile)
3699-
3700-
executeHolder.eventsManager.postFinished()
3701-
responseObserver.onNext(
3702-
proto.ExecutePlanResponse
3703-
.newBuilder()
3704-
.setSessionId(sessionId)
3705-
.setServerSideSessionId(sessionHolder.serverSessionId)
3706-
.setCreateResourceProfileCommandResult(
3707-
proto.CreateResourceProfileCommandResult
3708-
.newBuilder()
3709-
.setProfileId(profile.id)
3710-
.build())
3711-
.build())
3707+
CreateResourceProfileCommand(profile)
37123708
}
37133709

37143710
private def handleCheckpointCommand(

0 commit comments

Comments
 (0)