@@ -22,27 +22,31 @@ import io.substrait.spark.expression._
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.ResolvedIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Sum}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.LogicalRDD
-import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
+import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, CreateTableCommand, DropTableCommand}
 import org.apache.spark.sql.execution.datasources.{FileFormat => DSFileFormat, HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation, V1WriteCommand, WriteFiles}
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, V2SessionCatalog}
+import org.apache.spark.sql.hive.execution.{CreateHiveTableAsSelectCommand, InsertIntoHiveTable}
 import org.apache.spark.sql.types.{NullType, StructType}
 
 import io.substrait.`type`.{NamedStruct, Type}
 import io.substrait.{proto, relation}
 import io.substrait.debug.TreePrinter
 import io.substrait.expression.{Expression => SExpression, ExpressionCreator}
+import io.substrait.expression.Expression.StructLiteral
 import io.substrait.extension.ExtensionCollector
 import io.substrait.hint.Hint
 import io.substrait.plan.Plan
+import io.substrait.relation.AbstractDdlRel.{DdlObject, DdlOp}
 import io.substrait.relation.AbstractWriteRel.{CreateMode, OutputMode, WriteOp}
 import io.substrait.relation.RelProtoConverter
 import io.substrait.relation.Set.SetOp
@@ -54,7 +58,7 @@ import io.substrait.utils.Util
 import java.util
 import java.util.{Collections, Optional}
 
-import scala.collection.JavaConverters.asJavaIterableConverter
+import scala.collection.JavaConverters.{asJavaIterableConverter, seqAsJavaList}
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
@@ -75,9 +79,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
   override def default(p: LogicalPlan): relation.Rel = p match {
     case c: CommandResult => visit(c.commandLogicalPlan)
     case w: WriteFiles => visit(w.child)
-    case c: V1WriteCommand => convertDataWritingCommand(c)
-    case CreateDataSourceTableAsSelectCommand(table, mode, query, names) =>
-      convertCTAS(table, mode, query, names)
+    case c: Command => convertCommand(c)
     case p: LeafNode => convertReadOperator(p)
     case s: SubqueryAlias => visit(s.child)
     case v: View => visit(v.child)
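Note: the hunk above collapses the two command-specific cases into a single `case c: Command` funnel. A minimal sketch of the intended flow, assuming a local Hive-enabled SparkSession; the `demo` table name is illustrative, and `visit` is the visitor entry point this file already uses internally:

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: spark.sql executes DDL eagerly and returns a DataFrame whose
// logical plan is a CommandResult wrapping the executed command.
val spark = SparkSession.builder().master("local[1]").enableHiveSupport().getOrCreate()
val plan = spark.sql("CREATE TABLE demo (id INT, name STRING)").queryExecution.logical

// CommandResult -> visit(commandLogicalPlan) -> `case c: Command => convertCommand(c)`.
// Depending on Spark version and catalog config the command resolves to either
// CreateTableCommand (V1) or CreateTable with a ResolvedIdentifier (V2);
// both now reach convertCommand.
val rel = new ToSubstraitRel().visit(plan)
```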
@@ -566,6 +568,28 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
     }
   }
 
+  private def convertCommand(command: Command): relation.Rel = command match {
+    case c: V1WriteCommand => convertDataWritingCommand(c)
+    case CreateDataSourceTableAsSelectCommand(table, mode, query, names) =>
+      convertCTAS(table, mode, query, names)
+    case CreateHiveTableAsSelectCommand(table, query, names, mode) =>
+      convertCTAS(table, mode, query, names)
+    case CreateTableCommand(table, _) =>
+      convertCreateTable(table.identifier.unquotedString.split("\\."), table.schema)
+    case DropTableCommand(tableName, ifExists, _, _) =>
+      convertDropTable(tableName.unquotedString.split("\\."), ifExists)
+    case CreateTable(ResolvedIdentifier(c: V2SessionCatalog, id), tableSchema, _, _, _)
+        if id.namespace().length > 0 =>
+      val names = Seq(c.name(), id.namespace()(0), id.name())
+      convertCreateTable(names, tableSchema)
+    case DropTable(ResolvedIdentifier(c: V2SessionCatalog, id), ifExists, _)
+        if id.namespace().length > 0 =>
+      val names = Seq(c.name(), id.namespace()(0), id.name())
+      convertDropTable(names, ifExists)
+    case _ =>
+      throw new UnsupportedOperationException(s"Unable to convert command: $command")
+  }
+
   private def convertDataWritingCommand(command: V1WriteCommand): relation.AbstractWriteRel =
     command match {
       case InsertIntoHadoopFsRelationCommand(
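Note: `convertCommand` normalizes both V1 and V2 table references into a plain name list before building the Substrait rel; the `split("\\.")` regex matches a literal dot. A tiny illustration, with hypothetical identifier values:

```scala
// V1 commands carry a dotted identifier string, split on literal dots:
val v1Names = "default.my_table".split("\\.").toSeq // Seq("default", "my_table")

// V2 plans carry a ResolvedIdentifier, so the parts are read directly and
// prefixed with the session catalog's name:
//   Seq(catalog.name(), id.namespace()(0), id.name())
//   e.g. Seq("spark_catalog", "default", "my_table")
```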
@@ -600,6 +624,16 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
           .tableSchema(outputSchema(child.output, outputColumnNames))
           .detail(FileHolder(file))
           .build()
+      case InsertIntoHiveTable(table, _, child, overwrite, _, outputColumnNames, _, _, _, _, _) =>
+        relation.NamedWrite
+          .builder()
+          .input(visit(child))
+          .operation(WriteOp.INSERT)
+          .outputMode(OutputMode.UNSPECIFIED)
+          .createMode(if (overwrite) CreateMode.REPLACE_IF_EXISTS else CreateMode.ERROR_IF_EXISTS)
+          .names(seqAsJavaList(table.identifier.unquotedString.split("\\.").toList))
+          .tableSchema(outputSchema(child.output, outputColumnNames))
+          .build()
       case _ =>
         throw new UnsupportedOperationException(s"Unable to convert command: ${command.getClass}")
     }
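Note: the new `InsertIntoHiveTable` case mirrors the existing V1 write handling but produces a `NamedWrite` (a catalog-table write) rather than a file write. A sketch of the mode mapping, reusing the session and `demo` table from the earlier sketch (SQL illustrative):

```scala
// overwrite = false -> CreateMode.ERROR_IF_EXISTS
val append = spark.sql("INSERT INTO demo VALUES (1, 'a')").queryExecution.logical

// overwrite = true -> CreateMode.REPLACE_IF_EXISTS
val replace = spark.sql("INSERT OVERWRITE TABLE demo VALUES (2, 'b')").queryExecution.logical

// For a Hive-format table both plans reach the new InsertIntoHiveTable case
// and yield a relation.NamedWrite with WriteOp.INSERT.
new ToSubstraitRel().visit(append)
new ToSubstraitRel().visit(replace)
```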
@@ -619,6 +653,29 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
       .tableSchema(outputSchema(query.output, outputColumnNames))
       .build()
 
+  private def convertCreateTable(names: Seq[String], schema: StructType): relation.NamedDdl = {
+    relation.NamedDdl
+      .builder()
+      .operation(DdlOp.CREATE)
+      .`object`(DdlObject.TABLE)
+      .names(seqAsJavaList(names))
+      .tableSchema(ToSubstraitType.toNamedStruct(schema))
+      .tableDefaults(StructLiteral.builder.nullable(true).build())
+      .build()
+  }
+
+  private def convertDropTable(names: Seq[String], ifExists: Boolean): relation.NamedDdl = {
+    relation.NamedDdl
+      .builder()
+      .operation(if (ifExists) DdlOp.DROP_IF_EXIST else DdlOp.DROP)
+      .`object`(DdlObject.TABLE)
+      .names(seqAsJavaList(names))
+      .tableSchema(
+        NamedStruct.builder().struct(Type.Struct.builder().nullable(true).build()).build())
+      .tableDefaults(StructLiteral.builder.nullable(true).build())
+      .build()
+  }
+
   private def createMode(mode: SaveMode): CreateMode = mode match {
     case SaveMode.Append => CreateMode.APPEND_IF_EXISTS
     case SaveMode.Overwrite => CreateMode.REPLACE_IF_EXISTS
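Note: for reference, the expected DDL mapping, continuing the same hypothetical session; the rendering of the `NamedDdl` fields in the comments is informal, not the library's actual toString:

```scala
// CREATE TABLE demo (...)   -> NamedDdl(DdlOp.CREATE, DdlObject.TABLE, schema of demo)
// DROP TABLE demo           -> NamedDdl(DdlOp.DROP, DdlObject.TABLE, empty schema)
// DROP TABLE IF EXISTS demo -> NamedDdl(DdlOp.DROP_IF_EXIST, DdlObject.TABLE, empty schema)
val dropPlan = spark.sql("DROP TABLE IF EXISTS demo").queryExecution.logical
val ddl = new ToSubstraitRel().visit(dropPlan) // a relation.NamedDdl
```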