@@ -34,7 +34,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException,
34
34
import org .apache .spark .annotation .{DeveloperApi , Since }
35
35
import org .apache .spark .api .python .{PythonEvalType , SimplePythonFunction }
36
36
import org .apache .spark .connect .proto
37
- import org .apache .spark .connect .proto .{CheckpointCommand , CreateResourceProfileCommand , ExecutePlanResponse , SqlCommand , StreamingForeachFunction , StreamingQueryCommand , StreamingQueryCommandResult , StreamingQueryInstanceId , StreamingQueryManagerCommand , StreamingQueryManagerCommandResult , WriteStreamOperationStart , WriteStreamOperationStartResult }
37
+ import org .apache .spark .connect .proto .{CheckpointCommand , ExecutePlanResponse , SqlCommand , StreamingForeachFunction , StreamingQueryCommand , StreamingQueryCommandResult , StreamingQueryInstanceId , StreamingQueryManagerCommand , StreamingQueryManagerCommandResult , WriteStreamOperationStart , WriteStreamOperationStartResult }
38
38
import org .apache .spark .connect .proto .ExecutePlanResponse .SqlCommandResult
39
39
import org .apache .spark .connect .proto .Parse .ParseFormat
40
40
import org .apache .spark .connect .proto .StreamingQueryManagerCommandResult .StreamingQueryInstance
@@ -62,6 +62,7 @@ import org.apache.spark.sql.classic.ClassicConversions._
62
62
import org .apache .spark .sql .connect .client .arrow .ArrowSerializer
63
63
import org .apache .spark .sql .connect .common .{DataTypeProtoConverter , ForeachWriterPacket , LiteralValueProtoConverter , StorageLevelProtoConverter , StreamingListenerPacket , UdfPacket }
64
64
import org .apache .spark .sql .connect .config .Connect .CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
65
+ import org .apache .spark .sql .connect .execution .command .{ConnectLeafRunnableCommand , CreateResourceProfileCommand }
65
66
import org .apache .spark .sql .connect .ml .MLHandler
66
67
import org .apache .spark .sql .connect .pipelines .PipelinesHandler
67
68
import org .apache .spark .sql .connect .plugin .SparkConnectPluginRegistry
@@ -2654,6 +2655,8 @@ class SparkConnectPlanner(
2654
2655
Some (transformMergeIntoTableCommand(command.getMergeIntoTableCommand))
2655
2656
case proto.Command .CommandTypeCase .CREATE_DATAFRAME_VIEW =>
2656
2657
Some (_ => transformCreateViewCommand(command.getCreateDataframeView))
2658
+ case proto.Command .CommandTypeCase .CREATE_RESOURCE_PROFILE_COMMAND =>
2659
+ Some (_ => transformCreateResourceProfileCommand(command.getCreateResourceProfileCommand))
2657
2660
case _ =>
2658
2661
None
2659
2662
}
@@ -2664,7 +2667,7 @@ class SparkConnectPlanner(
2664
2667
responseObserver : StreamObserver [ExecutePlanResponse ]): Unit = {
2665
2668
val transformerOpt = transformCommand(command)
2666
2669
if (transformerOpt.isDefined) {
2667
- transformAndRunCommand(transformerOpt.get)
2670
+ transformAndRunCommand(transformerOpt.get, responseObserver )
2668
2671
return
2669
2672
}
2670
2673
command.getCommandTypeCase match {
@@ -2693,10 +2696,6 @@ class SparkConnectPlanner(
2693
2696
responseObserver)
2694
2697
case proto.Command .CommandTypeCase .GET_RESOURCES_COMMAND =>
2695
2698
handleGetResourcesCommand(responseObserver)
2696
- case proto.Command .CommandTypeCase .CREATE_RESOURCE_PROFILE_COMMAND =>
2697
- handleCreateResourceProfileCommand(
2698
- command.getCreateResourceProfileCommand,
2699
- responseObserver)
2700
2699
case proto.Command .CommandTypeCase .CHECKPOINT_COMMAND =>
2701
2700
handleCheckpointCommand(command.getCheckpointCommand, responseObserver)
2702
2701
case proto.Command .CommandTypeCase .REMOVE_CACHED_REMOTE_RELATION_COMMAND =>
@@ -3153,11 +3152,22 @@ class SparkConnectPlanner(
3153
3152
}
3154
3153
}
3155
3154
3156
- private def transformAndRunCommand (transformer : QueryPlanningTracker => LogicalPlan ): Unit = {
3155
+ private def transformAndRunCommand (
3156
+ transformer : QueryPlanningTracker => LogicalPlan ,
3157
+ responseObserver : StreamObserver [ExecutePlanResponse ]): Unit = {
3157
3158
val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
3158
3159
val qe = new QueryExecution (session, transformer(tracker), tracker)
3159
3160
qe.assertCommandExecuted()
3160
3161
executeHolder.eventsManager.postFinished()
3162
+ qe.logical match {
3163
+ case connectCommand : ConnectLeafRunnableCommand =>
3164
+ connectCommand.handleConnectResponse(
3165
+ responseObserver,
3166
+ sessionId,
3167
+ sessionHolder.serverSessionId)
3168
+ case _ =>
3169
+ // Do nothing
3170
+ }
3161
3171
}
3162
3172
3163
3173
/**
@@ -3674,9 +3684,8 @@ class SparkConnectPlanner(
3674
3684
.build())
3675
3685
}
3676
3686
3677
- private def handleCreateResourceProfileCommand (
3678
- createResourceProfileCommand : CreateResourceProfileCommand ,
3679
- responseObserver : StreamObserver [proto.ExecutePlanResponse ]): Unit = {
3687
+ private def transformCreateResourceProfileCommand (
3688
+ createResourceProfileCommand : proto.CreateResourceProfileCommand ): LogicalPlan = {
3680
3689
val rp = createResourceProfileCommand.getProfile
3681
3690
val ereqs = rp.getExecutorResourcesMap.asScala.map { case (name, res) =>
3682
3691
name -> new ExecutorResourceRequest (
@@ -3695,20 +3704,7 @@ class SparkConnectPlanner(
3695
3704
} else {
3696
3705
new ResourceProfile (ereqs, treqs)
3697
3706
}
3698
- session.sparkContext.resourceProfileManager.addResourceProfile(profile)
3699
-
3700
- executeHolder.eventsManager.postFinished()
3701
- responseObserver.onNext(
3702
- proto.ExecutePlanResponse
3703
- .newBuilder()
3704
- .setSessionId(sessionId)
3705
- .setServerSideSessionId(sessionHolder.serverSessionId)
3706
- .setCreateResourceProfileCommandResult(
3707
- proto.CreateResourceProfileCommandResult
3708
- .newBuilder()
3709
- .setProfileId(profile.id)
3710
- .build())
3711
- .build())
3707
+ CreateResourceProfileCommand (profile)
3712
3708
}
3713
3709
3714
3710
private def handleCheckpointCommand (
0 commit comments