4 changes: 4 additions & 0 deletions spark/.gitignore
@@ -0,0 +1,4 @@
metastore_db
spark-warehouse
/src/test/resources/write-a.csv
derby.log
6 changes: 5 additions & 1 deletion spark/build.gradle.kts
@@ -116,6 +116,7 @@ dependencies {
implementation(libs.scala.library)
api(libs.spark.core)
api(libs.spark.sql)
implementation(libs.spark.hive)
implementation(libs.spark.catalyst)
implementation(libs.slf4j.api)

@@ -148,6 +149,9 @@ tasks {
test {
dependsOn(":core:shadowJar")
useJUnitPlatform { includeEngines("scalatest") }
jvmArgs("--add-exports=java.base/sun.nio.ch=ALL-UNNAMED")
jvmArgs(
"--add-exports=java.base/sun.nio.ch=ALL-UNNAMED",
"--add-opens=java.base/java.net=ALL-UNNAMED",
)
}
}
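
The new spark.hive dependency and the extra --add-opens flag support tests that run against a Hive-backed session catalog (which is also what produces the metastore_db, spark-warehouse, and derby.log artifacts ignored in spark/.gitignore above). A minimal sketch of such a test session, not part of this diff and using standard Spark configuration keys:

// Hedged sketch: a local Hive-enabled SparkSession, as the new write/DDL tests would need.
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .master("local[2]")
  .config("spark.sql.warehouse.dir", "spark-warehouse") // matches the ignored directory above
  .enableHiveSupport() // sets spark.sql.catalogImplementation to "hive", which ToLogicalPlan checks
  .getOrCreate()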
109 changes: 79 additions & 30 deletions spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -29,27 +29,31 @@ import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOute
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, LeafRunnableCommand}
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, CreateTableCommand, DataWritingCommand, DropTableCommand, LeafRunnableCommand}
import org.apache.spark.sql.execution.datasources.{FileFormat => SparkFileFormat, HadoopFsRelation, InMemoryFileIndex, InsertIntoHadoopFsRelationCommand, LogicalRelation, V1Writes}
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.hive.execution.{CreateHiveTableAsSelectCommand, InsertIntoHiveTable}
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}

import io.substrait.`type`.{NamedStruct, StringTypeVisitor, Type}
import io.substrait.{expression => exp}
import io.substrait.expression.{Expression => SExpression}
import io.substrait.plan.Plan
import io.substrait.relation
import io.substrait.relation.{ExtensionWrite, LocalFiles, NamedWrite}
import io.substrait.relation.{ExtensionWrite, LocalFiles, NamedDdl, NamedWrite}
import io.substrait.relation.AbstractDdlRel.{DdlObject, DdlOp}
import io.substrait.relation.AbstractWriteRel.{CreateMode, WriteOp}
import io.substrait.relation.Expand.{ConsistentField, SwitchingField}
import io.substrait.relation.Set.SetOp
import io.substrait.relation.files.FileFormat
import io.substrait.util.EmptyVisitationContext
import org.apache.hadoop.fs.Path

import java.net.URI

import scala.collection.JavaConverters.asScalaBufferConverter
import scala.collection.mutable.ArrayBuffer

@@ -437,35 +441,44 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())

override def visit(write: NamedWrite, context: EmptyVisitationContext): LogicalPlan = {
val child = write.getInput.accept(this, context)

val (table, database, catalog) = write.getNames.asScala match {
case Seq(table) => (table, None, None)
case Seq(database, table) => (table, Some(database), None)
case Seq(catalog, database, table) => (table, Some(database), Some(catalog))
case names =>
throw new UnsupportedOperationException(
s"NamedWrite requires up to three names ([[catalog,] database,] table): $names")
val table = catalogTable(write.getNames.asScala)
val isHive = spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION.key) match {
case "hive" => true
case _ => false
}
val id = TableIdentifier(table, database, catalog)
val catalogTable = CatalogTable(
id,
CatalogTableType.MANAGED,
CatalogStorageFormat.empty,
new StructType(),
Some("parquet")
)
write.getOperation match {
case WriteOp.CTAS =>
withChild(child) {
CreateDataSourceTableAsSelectCommand(
catalogTable,
saveMode(write.getCreateMode),
if (isHive) {
CreateHiveTableAsSelectCommand(
table,
child,
write.getTableSchema.names().asScala,
saveMode(write.getCreateMode)
)
} else {
CreateDataSourceTableAsSelectCommand(
table,
saveMode(write.getCreateMode),
child,
write.getTableSchema.names().asScala
)
}
}
case WriteOp.INSERT =>
withChild(child) {
InsertIntoHiveTable(
table,
Map.empty,
child,
write.getCreateMode == CreateMode.REPLACE_IF_EXISTS,
false,
write.getTableSchema.names().asScala
)
}
case op => throw new UnsupportedOperationException(s"Write mode $op not supported")
}

}
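
The CTAS branch above emits a Hive or a data-source command depending on the session's catalog implementation, and INSERT is rebuilt as InsertIntoHiveTable. A hedged round-trip sketch, assuming convert(...) entry points on both converters (SQL and identifiers are illustrative):

// Hedged sketch: round-tripping a CTAS through Substrait; convert(...) signatures assumed.
val plan = spark.sql("CREATE TABLE demo.t AS SELECT 1 AS id").queryExecution.optimizedPlan
val substraitPlan = new ToSubstraitRel().convert(plan) // Spark command -> Substrait NamedWrite (CTAS)
val restored = new ToLogicalPlan(spark).convert(substraitPlan) // Substrait -> Spark logical plan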

override def visit(write: ExtensionWrite, context: EmptyVisitationContext): LogicalPlan = {
@@ -491,14 +504,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
val (format, options) = convertFileFormat(file.getFileFormat.get)

val name = file.getPath.get.split('/').reverse.head
val id = TableIdentifier(name)
val table = CatalogTable(
id,
CatalogTableType.MANAGED,
CatalogStorageFormat.empty,
new StructType(),
None
)
val table = catalogTable(Seq(name))

withChild(child) {
V1Writes.apply(
@@ -519,6 +525,49 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
}
}

override def visit(ddl: NamedDdl, context: EmptyVisitationContext): LogicalPlan = {
val table = catalogTable(ddl.getNames.asScala, ToSparkType.toStructType(ddl.getTableSchema))

(ddl.getOperation, ddl.getObject) match {
case (DdlOp.CREATE, DdlObject.TABLE) => CreateTableCommand(table, false)
case (DdlOp.DROP, DdlObject.TABLE) => DropTableCommand(table.identifier, false, false, false)
case (DdlOp.DROP_IF_EXIST, DdlObject.TABLE) =>
DropTableCommand(table.identifier, true, false, false)
case op => throw new UnsupportedOperationException(s"Ddl operation $op not supported")
}
}
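
For orientation, a hedged sketch of the Spark command this visitor produces for a DROP TABLE IF EXISTS, written out with illustrative identifiers:

// Hedged sketch: the target shape for (DdlOp.DROP_IF_EXIST, DdlObject.TABLE) with names ["demo", "t"].
val drop = DropTableCommand(
  TableIdentifier("t", Some("demo")),
  ifExists = true,
  isView = false,
  purge = false)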

private def catalogTable(
names: Seq[String],
schema: StructType = new StructType()): CatalogTable = {
val (table, database, catalog) = names match {
case Seq(table) => (table, None, None)
case Seq(database, table) => (table, Some(database), None)
case Seq(catalog, database, table) => (table, Some(database), Some(catalog))
case names =>
throw new UnsupportedOperationException(
s"NamedWrite requires up to three names ([[catalog,] database,] table): $names")
}

val loc = spark.conf.get(StaticSQLConf.WAREHOUSE_PATH.key)
val storage = CatalogStorageFormat(
locationUri = Some(URI.create(f"$loc/$table")),
inputFormat = None,
outputFormat = None,
serde = None,
compressed = false,
properties = Map.empty
)
val id = TableIdentifier(table, database, catalog)
CatalogTable(
id,
CatalogTableType.MANAGED,
storage,
schema,
Some("parquet")
)
}
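
For reference, a hedged sketch of the catalog entry this helper builds (the warehouse path value is illustrative):

// Hedged sketch: catalogTable(Seq("demo", "t")) with spark.sql.warehouse.dir = /tmp/wh
//   identifier  == TableIdentifier("t", Some("demo"), None)
//   locationUri == Some(URI.create("/tmp/wh/t"))
//   tableType   == CatalogTableType.MANAGED, provider == Some("parquet")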

private def saveMode(mode: CreateMode): SaveMode = mode match {
case CreateMode.APPEND_IF_EXISTS => SaveMode.Append
case CreateMode.REPLACE_IF_EXISTS => SaveMode.Overwrite
spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
@@ -22,27 +22,31 @@ import io.substrait.spark.expression._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.ResolvedIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Sum}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, CreateTableCommand, DropTableCommand}
import org.apache.spark.sql.execution.datasources.{FileFormat => DSFileFormat, HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation, V1WriteCommand, WriteFiles}
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, V2SessionCatalog}
import org.apache.spark.sql.hive.execution.{CreateHiveTableAsSelectCommand, InsertIntoHiveTable}
import org.apache.spark.sql.types.{NullType, StructType}

import io.substrait.`type`.{NamedStruct, Type}
import io.substrait.{proto, relation}
import io.substrait.debug.TreePrinter
import io.substrait.expression.{Expression => SExpression, ExpressionCreator}
import io.substrait.expression.Expression.StructLiteral
import io.substrait.extension.ExtensionCollector
import io.substrait.hint.Hint
import io.substrait.plan.Plan
import io.substrait.relation.AbstractDdlRel.{DdlObject, DdlOp}
import io.substrait.relation.AbstractWriteRel.{CreateMode, OutputMode, WriteOp}
import io.substrait.relation.RelProtoConverter
import io.substrait.relation.Set.SetOp
@@ -54,7 +58,7 @@ import io.substrait.utils.Util
import java.util
import java.util.{Collections, Optional}

import scala.collection.JavaConverters.asJavaIterableConverter
import scala.collection.JavaConverters.{asJavaIterableConverter, seqAsJavaList}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

@@ -75,9 +79,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
override def default(p: LogicalPlan): relation.Rel = p match {
case c: CommandResult => visit(c.commandLogicalPlan)
case w: WriteFiles => visit(w.child)
case c: V1WriteCommand => convertDataWritingCommand(c)
case CreateDataSourceTableAsSelectCommand(table, mode, query, names) =>
convertCTAS(table, mode, query, names)
case c: Command => convertCommand(c)
case p: LeafNode => convertReadOperator(p)
case s: SubqueryAlias => visit(s.child)
case v: View => visit(v.child)
@@ -566,6 +568,28 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
}
}

private def convertCommand(command: Command): relation.Rel = command match {
case c: V1WriteCommand => convertDataWritingCommand(c)
case CreateDataSourceTableAsSelectCommand(table, mode, query, names) =>
convertCTAS(table, mode, query, names)
case CreateHiveTableAsSelectCommand(table, query, names, mode) =>
convertCTAS(table, mode, query, names)
case CreateTableCommand(table, _) =>
convertCreateTable(table.identifier.unquotedString.split("\\."), table.schema)
case DropTableCommand(tableName, ifExists, _, _) =>
convertDropTable(tableName.unquotedString.split("\\."), ifExists)
case CreateTable(ResolvedIdentifier(c: V2SessionCatalog, id), tableSchema, _, _, _)
if id.namespace().length > 0 =>
val names = Seq(c.name(), id.namespace()(0), id.name())
convertCreateTable(names, tableSchema)
case DropTable(ResolvedIdentifier(c: V2SessionCatalog, id), ifExists, _)
if id.namespace().length > 0 =>
val names = Seq(c.name(), id.namespace()(0), id.name())
convertDropTable(names, ifExists)
case _ =>
throw new UnsupportedOperationException(s"Unable to convert command: $command")
}
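
The V2 branches assemble a fully qualified name from the session catalog and the resolved identifier. A hedged illustration, assuming the default catalog name:

// Hedged illustration: CREATE TABLE spark_catalog.demo.t (...) resolved through V2SessionCatalog gives
//   c.name() == "spark_catalog", id.namespace() == Array("demo"), id.name() == "t",
// so the Substrait names become:
val names = Seq("spark_catalog", "demo", "t")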

private def convertDataWritingCommand(command: V1WriteCommand): relation.AbstractWriteRel =
command match {
case InsertIntoHadoopFsRelationCommand(
@@ -600,6 +624,16 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
.tableSchema(outputSchema(child.output, outputColumnNames))
.detail(FileHolder(file))
.build()
case InsertIntoHiveTable(table, _, child, overwrite, _, outputColumnNames, _, _, _, _, _) =>
relation.NamedWrite
.builder()
.input(visit(child))
.operation(WriteOp.INSERT)
.outputMode(OutputMode.UNSPECIFIED)
.createMode(if (overwrite) CreateMode.REPLACE_IF_EXISTS else CreateMode.ERROR_IF_EXISTS)
.names(seqAsJavaList(table.identifier.unquotedString.split("\\.").toList))
.tableSchema(outputSchema(child.output, outputColumnNames))
.build()
case _ =>
throw new UnsupportedOperationException(s"Unable to convert command: ${command.getClass}")
}
@@ -619,6 +653,29 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
.tableSchema(outputSchema(query.output, outputColumnNames))
.build()

private def convertCreateTable(names: Seq[String], schema: StructType): relation.NamedDdl = {
relation.NamedDdl
.builder()
.operation(DdlOp.CREATE)
.`object`(DdlObject.TABLE)
.names(seqAsJavaList(names))
.tableSchema(ToSubstraitType.toNamedStruct(schema))
.tableDefaults(StructLiteral.builder.nullable(true).build())
.build()
}

private def convertDropTable(names: Seq[String], ifExists: Boolean): relation.NamedDdl = {
relation.NamedDdl
.builder()
.operation(if (ifExists) DdlOp.DROP_IF_EXIST else DdlOp.DROP)
.`object`(DdlObject.TABLE)
.names(seqAsJavaList(names))
.tableSchema(
NamedStruct.builder().struct(Type.Struct.builder().nullable(true).build()).build())
.tableDefaults(StructLiteral.builder.nullable(true).build())
.build()
}

private def createMode(mode: SaveMode): CreateMode = mode match {
case SaveMode.Append => CreateMode.APPEND_IF_EXISTS
case SaveMode.Overwrite => CreateMode.REPLACE_IF_EXISTS