diff --git a/build.sbt b/build.sbt index 93c73e3..5881c89 100644 --- a/build.sbt +++ b/build.sbt @@ -4,28 +4,18 @@ ThisBuild / scalacOptions ++= Seq("-deprecation", "-release:8") ThisBuild / javacOptions ++= List("-target", "8", "-source", "8") +ThisBuild / version := "0.0.1" + Global / concurrentRestrictions := Seq( Tags.limit(Tags.Test, 1) ) -val Scala212 = "2.12.20" - -val Scala213 = "2.13.16" - -lazy val Spark35 = Spark("3.5.3") - -lazy val Spark34 = Spark("3.4.4") - -lazy val Spark33 = Spark("3.3.4") +val Scala3 = "3.6.3" lazy val Spark32 = Spark("3.2.3") -lazy val Spark31 = Spark("3.1.3") - lazy val ScalaPB0_11 = ScalaPB("0.11.17") -lazy val ScalaPB0_10 = ScalaPB("0.10.11") - lazy val framelessDatasetName = settingKey[String]("frameless-dataset-name") lazy val framelessDatasetVersion = settingKey[String]("frameless-dataset-version") @@ -50,14 +40,11 @@ lazy val `sparksql-scalapb` = (projectMatrix in file("sparksql-scalapb")) .defaultAxes() .settings( libraryDependencies ++= Seq( - "org.typelevel" %% framelessDatasetName.value % framelessDatasetVersion.value, + "org.typelevel" %% framelessDatasetName.value % framelessDatasetVersion.value cross CrossVersion.for3Use2_13, "com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.value.scalapbVersion, "com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.value.scalapbVersion % "protobuf", - "org.apache.spark" %% "spark-sql" % spark.value.sparkVersion % "provided", - "org.apache.spark" %% "spark-sql" % spark.value.sparkVersion % "test", - "org.scalatest" %% "scalatest" % "3.2.19" % "test", - "org.scalatestplus" %% "scalacheck-1-17" % "3.2.18.0" % "test", - "com.github.alexarchambault" %% "scalacheck-shapeless_1.16" % "1.3.1" % "test" + "org.apache.spark" %% "spark-sql" % spark.value.sparkVersion % "provided" cross CrossVersion.for3Use2_13, + "org.apache.spark" %% "spark-sql" % spark.value.sparkVersion % "test" cross CrossVersion.for3Use2_13 ), spark := { virtualAxes.value @@ -83,18 +70,18 @@ lazy val `sparksql-scalapb` = (projectMatrix in file("sparksql-scalapb")) }, framelessDatasetName := { spark.value match { - case Spark35 | Spark34 | Spark33 => "frameless-dataset" - case Spark32 => "frameless-dataset-spark32" - case Spark31 => "frameless-dataset-spark31" - case _ => ??? +// case Spark35 | Spark34 | Spark33 => "frameless-dataset" + case Spark32 => "frameless-dataset-spark32" +// case Spark31 => "frameless-dataset-spark31" + case _ => ??? } }, framelessDatasetVersion := { spark.value match { - case Spark35 | Spark34 | Spark33 => "0.16.0" // NPE in 3.4, 3.5 if older lib versions used - case Spark32 => "0.15.0" // Spark3.2 support dropped in ver > 0.15.0 - case Spark31 => "0.14.0" // Spark3.1 support dropped in ver > 0.14.0 - case _ => ??? +// case Spark35 | Spark34 | Spark33 => "0.16.0" // NPE in 3.4, 3.5 if older lib versions used + case Spark32 => "0.15.0" // Spark3.2 support dropped in ver > 0.15.0 +// case Spark31 => "0.14.0" // Spark3.1 support dropped in ver > 0.14.0 + case _ => ??? 
} }, name := s"sparksql${spark.value.majorVersion}${spark.value.minorVersion}-${scalapb.value.idSuffix}", @@ -107,46 +94,46 @@ lazy val `sparksql-scalapb` = (projectMatrix in file("sparksql-scalapb")) Test / run / fork := true, Test / javaOptions ++= Seq("-Xmx2G") ) +// .customRow( +// scalaVersions = Seq(Scala212, Scala213), +// axisValues = Seq(Spark35, ScalaPB0_11, VirtualAxis.jvm), +// settings = Seq() +// ) +// .customRow( +// scalaVersions = Seq(Scala212, Scala213), +// axisValues = Seq(Spark34, ScalaPB0_11, VirtualAxis.jvm), +// settings = Seq() +// ) +// .customRow( +// scalaVersions = Seq(Scala212, Scala213), +// axisValues = Seq(Spark33, ScalaPB0_11, VirtualAxis.jvm), +// settings = Seq() +// ) .customRow( - scalaVersions = Seq(Scala212, Scala213), - axisValues = Seq(Spark35, ScalaPB0_11, VirtualAxis.jvm), - settings = Seq() - ) - .customRow( - scalaVersions = Seq(Scala212, Scala213), - axisValues = Seq(Spark34, ScalaPB0_11, VirtualAxis.jvm), - settings = Seq() - ) - .customRow( - scalaVersions = Seq(Scala212, Scala213), - axisValues = Seq(Spark33, ScalaPB0_11, VirtualAxis.jvm), - settings = Seq() - ) - .customRow( - scalaVersions = Seq(Scala212, Scala213), + scalaVersions = Seq(Scala3), axisValues = Seq(Spark32, ScalaPB0_11, VirtualAxis.jvm), settings = Seq() ) - .customRow( - scalaVersions = Seq(Scala212), - axisValues = Seq(Spark31, ScalaPB0_11, VirtualAxis.jvm), - settings = Seq() - ) - .customRow( - scalaVersions = Seq(Scala212, Scala213), - axisValues = Seq(Spark33, ScalaPB0_10, VirtualAxis.jvm), - settings = Seq() - ) - .customRow( - scalaVersions = Seq(Scala212, Scala213), - axisValues = Seq(Spark32, ScalaPB0_10, VirtualAxis.jvm), - settings = Seq() - ) - .customRow( - scalaVersions = Seq(Scala212), - axisValues = Seq(Spark31, ScalaPB0_10, VirtualAxis.jvm), - settings = Seq() - ) +// .customRow( +// scalaVersions = Seq(Scala212), +// axisValues = Seq(Spark31, ScalaPB0_11, VirtualAxis.jvm), +// settings = Seq() +// ) +// .customRow( +// scalaVersions = Seq(Scala212, Scala213), +// axisValues = Seq(Spark33, ScalaPB0_10, VirtualAxis.jvm), +// settings = Seq() +// ) +// .customRow( +// scalaVersions = Seq(Scala212, Scala213), +// axisValues = Seq(Spark32, ScalaPB0_10, VirtualAxis.jvm), +// settings = Seq() +// ) +// .customRow( +// scalaVersions = Seq(Scala212), +// axisValues = Seq(Spark31, ScalaPB0_10, VirtualAxis.jvm), +// settings = Seq() +// ) ThisBuild / publishTo := sonatypePublishToBundle.value diff --git a/sparksql-scalapb/src/main/scala/scalapb/spark/CustomTypedEncoders.scala b/sparksql-scalapb/src/main/scala/scalapb/spark/CustomTypedEncoders.scala index e74927c..7f96a02 100644 --- a/sparksql-scalapb/src/main/scala/scalapb/spark/CustomTypedEncoders.scala +++ b/sparksql-scalapb/src/main/scala/scalapb/spark/CustomTypedEncoders.scala @@ -12,8 +12,8 @@ object CustomTypedEncoders { val timestampToSqlTimestamp: TypedEncoder[Timestamp] = fromInjection( Injection[Timestamp, SQLTimestamp]( - { ts: Timestamp => SQLTimestamp(TimestampHelpers.toMicros(ts)) }, - { timestamp: SQLTimestamp => TimestampHelpers.fromMicros(timestamp.us) } + { (ts: Timestamp) => SQLTimestamp(TimestampHelpers.toMicros(ts)) }, + { (timestamp: SQLTimestamp) => TimestampHelpers.fromMicros(timestamp.us) } ) ) diff --git a/sparksql-scalapb/src/main/scala/scalapb/spark/FromCatalystHelpers.scala b/sparksql-scalapb/src/main/scala/scalapb/spark/FromCatalystHelpers.scala index 0fffcda..68e0170 100644 --- a/sparksql-scalapb/src/main/scala/scalapb/spark/FromCatalystHelpers.scala +++ 
b/sparksql-scalapb/src/main/scala/scalapb/spark/FromCatalystHelpers.scala @@ -24,7 +24,7 @@ trait FromCatalystHelpers { def schemaOptions: SchemaOptions = protoSql.schemaOptions def pmessageFromCatalyst( - cmp: GeneratedMessageCompanion[_], + cmp: GeneratedMessageCompanion[?], input: Expression ): Expression = { schemaOptions.messageEncoders.get(cmp.scalaDescriptor) match { @@ -55,7 +55,7 @@ trait FromCatalystHelpers { def pmessageFromCatalyst( input: Expression, - cmp: GeneratedMessageCompanion[_], + cmp: GeneratedMessageCompanion[?], args: Seq[Expression] ): Expression = { val outputType = ObjectType(classOf[PValue]) @@ -74,7 +74,7 @@ trait FromCatalystHelpers { } def fieldFromCatalyst( - cmp: GeneratedMessageCompanion[_], + cmp: GeneratedMessageCompanion[?], fd: FieldDescriptor, input: Expression ): Expression = { @@ -118,7 +118,7 @@ trait FromCatalystHelpers { } def singleFieldValueFromCatalyst( - cmp: GeneratedMessageCompanion[_], + cmp: GeneratedMessageCompanion[?], fd: FieldDescriptor, input: Expression ): Expression = { @@ -182,7 +182,7 @@ case class MyUnresolvedCatalystToExternalMap( @transient keyFunction: Expression => Expression, @transient valueFunction: Expression => Expression, mapType: MapType, - collClass: Class[_] + collClass: Class[?] ) object MyCatalystToExternalMap { diff --git a/sparksql-scalapb/src/main/scala/scalapb/spark/JavaHelpers.scala b/sparksql-scalapb/src/main/scala/scalapb/spark/JavaHelpers.scala index e46881b..41146f7 100644 --- a/sparksql-scalapb/src/main/scala/scalapb/spark/JavaHelpers.scala +++ b/sparksql-scalapb/src/main/scala/scalapb/spark/JavaHelpers.scala @@ -7,7 +7,7 @@ import scalapb.{GeneratedEnum, GeneratedEnumCompanion, GeneratedMessage, Generat object JavaHelpers { def enumToString( - cmp: GeneratedEnumCompanion[_], + cmp: GeneratedEnumCompanion[?], value: GeneratedEnum ): UTF8String = { UTF8String.fromString( @@ -52,7 +52,7 @@ object JavaHelpers { } def penumFromString( - cmp: GeneratedEnumCompanion[_], + cmp: GeneratedEnumCompanion[?], inputUtf8: UTF8String ): PValue = { val input = inputUtf8.toString @@ -73,7 +73,7 @@ object JavaHelpers { } } - def mkMap(cmp: GeneratedMessageCompanion[_], args: ArrayData): Map[FieldDescriptor, PValue] = { + def mkMap(cmp: GeneratedMessageCompanion[?], args: ArrayData): Map[FieldDescriptor, PValue] = { cmp.scalaDescriptor.fields .zip(args.array) .filter { @@ -89,7 +89,7 @@ object JavaHelpers { } def mkPRepeatedMap( - mapEntryCmp: GeneratedMessageCompanion[_], + mapEntryCmp: GeneratedMessageCompanion[?], args: Vector[(PValue, PValue)] ): PValue = { val keyDesc = mapEntryCmp.scalaDescriptor.findFieldByNumber(1).get @@ -102,7 +102,7 @@ object JavaHelpers { // ExternalMapToCatalyst only needs iterator. We create this view into s to // avoid making a copy. 
def mkMap(s: Seq[GeneratedMessage]): Map[Any, Any] = - scalapb.spark.internal.MapHelpers.fromIterator { + MapHelpers.fromIterator { if (s.isEmpty) Iterator.empty else { val cmp = s.head.companion diff --git a/sparksql-scalapb/src/main/scala-2.13/scalapb/spark/internal/MapHelpers.scala b/sparksql-scalapb/src/main/scala/scalapb/spark/MapHelpers.scala similarity index 93% rename from sparksql-scalapb/src/main/scala-2.13/scalapb/spark/internal/MapHelpers.scala rename to sparksql-scalapb/src/main/scala/scalapb/spark/MapHelpers.scala index b0b0a76..92a4e57 100644 --- a/sparksql-scalapb/src/main/scala-2.13/scalapb/spark/internal/MapHelpers.scala +++ b/sparksql-scalapb/src/main/scala/scalapb/spark/MapHelpers.scala @@ -1,4 +1,4 @@ -package scalapb.spark.internal +package scalapb.spark private[spark] object MapHelpers { def fromIterator[K, V](it: => Iterator[(K, V)]): Map[K, V] = new Map[K, V] { diff --git a/sparksql-scalapb/src/main/scala/scalapb/spark/SchemaOptions.scala b/sparksql-scalapb/src/main/scala/scalapb/spark/SchemaOptions.scala index 08a21b7..464e7c5 100644 --- a/sparksql-scalapb/src/main/scala/scalapb/spark/SchemaOptions.scala +++ b/sparksql-scalapb/src/main/scala/scalapb/spark/SchemaOptions.scala @@ -8,7 +8,7 @@ import frameless.TypedEncoder case class SchemaOptions( columnNaming: ColumnNaming, retainPrimitiveWrappers: Boolean, - messageEncoders: Map[Descriptor, TypedEncoder[_]] + messageEncoders: Map[Descriptor, TypedEncoder[?]] ) { def withScalaNames = copy(columnNaming = ColumnNaming.ScalaNames) @@ -16,7 +16,7 @@ case class SchemaOptions( def withRetainedPrimitiveWrappers = copy(retainPrimitiveWrappers = true) - def withMessageEncoders(messageEncoders: Map[Descriptor, TypedEncoder[_]]) = + def withMessageEncoders(messageEncoders: Map[Descriptor, TypedEncoder[?]]) = copy(messageEncoders = messageEncoders) def addMessageEncoder[T <: GeneratedMessage]( diff --git a/sparksql-scalapb/src/main/scala/scalapb/spark/ToCatalystHelpers.scala b/sparksql-scalapb/src/main/scala/scalapb/spark/ToCatalystHelpers.scala index b7df96e..14229c0 100644 --- a/sparksql-scalapb/src/main/scala/scalapb/spark/ToCatalystHelpers.scala +++ b/sparksql-scalapb/src/main/scala/scalapb/spark/ToCatalystHelpers.scala @@ -20,7 +20,7 @@ trait ToCatalystHelpers { def schemaOptions: SchemaOptions def messageToCatalyst( - cmp: GeneratedMessageCompanion[_], + cmp: GeneratedMessageCompanion[?], input: Expression ): Expression = { schemaOptions.messageEncoders.get(cmp.scalaDescriptor) match { @@ -52,7 +52,7 @@ trait ToCatalystHelpers { } def fieldGetterAndTransformer( - cmp: GeneratedMessageCompanion[_], + cmp: GeneratedMessageCompanion[?], fd: FieldDescriptor ): (Expression => Expression, Expression => Expression) = { def messageFieldCompanion = cmp.messageCompanionForFieldNumber(fd.number) @@ -73,7 +73,7 @@ trait ToCatalystHelpers { Nil ), "findFieldByNumber", - ObjectType(classOf[Option[_]]), + ObjectType(classOf[Option[?]]), Literal(fd.number) :: Nil ), "get", @@ -86,18 +86,18 @@ trait ToCatalystHelpers { inputObject, "getFieldByNumber", if (fd.isRepeated) - ObjectType(classOf[Seq[_]]) + ObjectType(classOf[Seq[?]]) else ObjectType(messageFieldCompanion.defaultInstance.getClass), Literal(fd.number, IntegerType) :: Nil ) if (!isMessage) { - (getField, { e: Expression => singularFieldToCatalyst(fd, e) }) + (getField, { (e: Expression) => singularFieldToCatalyst(fd, e) }) } else { ( getFieldByNumber, - { e: Expression => + { (e: Expression) => messageToCatalyst(messageFieldCompanion, e) } ) @@ -105,7 +105,7 @@ trait 
ToCatalystHelpers { } def fieldToCatalyst( - cmp: GeneratedMessageCompanion[_], + cmp: GeneratedMessageCompanion[?], fd: FieldDescriptor, inputObject: Expression ): Expression = { @@ -132,7 +132,7 @@ trait ToCatalystHelpers { ExternalMapToCatalyst( StaticInvoke( JavaHelpers.getClass, - ObjectType(classOf[Map[_, _]]), + ObjectType(classOf[Map[?, ?]]), "mkMap", fieldGetter(inputObject) :: Nil ), @@ -152,7 +152,7 @@ trait ToCatalystHelpers { else { val getter = StaticInvoke( JavaHelpers.getClass, - ObjectType(classOf[Vector[_]]), + ObjectType(classOf[Vector[?]]), "vectorFromPValue", fieldGetter(inputObject) :: Nil ) diff --git a/sparksql-scalapb/src/main/scala/scalapb/spark/TypeMappers.scala b/sparksql-scalapb/src/main/scala/scalapb/spark/TypeMappers.scala index 99d6630..37ebec1 100644 --- a/sparksql-scalapb/src/main/scala/scalapb/spark/TypeMappers.scala +++ b/sparksql-scalapb/src/main/scala/scalapb/spark/TypeMappers.scala @@ -9,9 +9,9 @@ import java.time.Instant object TypeMappers { implicit val googleTsToSqlTsMapper: TypeMapper[GoogleTimestamp, SQLTimestamp] = - TypeMapper({ googleTs: GoogleTimestamp => + TypeMapper({ (googleTs: GoogleTimestamp) => SQLTimestamp.from(Instant.ofEpochSecond(googleTs.seconds, googleTs.nanos)) - })({ sqlTs: SQLTimestamp => + })({ (sqlTs: SQLTimestamp) => val instant = sqlTs.toInstant GoogleTimestamp(instant.getEpochSecond, instant.getNano) }) diff --git a/sparksql-scalapb/src/main/scala/scalapb/spark/TypedEncoders.scala b/sparksql-scalapb/src/main/scala/scalapb/spark/TypedEncoders.scala index 2ec0564..2b3ef72 100644 --- a/sparksql-scalapb/src/main/scala/scalapb/spark/TypedEncoders.scala +++ b/sparksql-scalapb/src/main/scala/scalapb/spark/TypedEncoders.scala @@ -28,11 +28,11 @@ trait TypedEncoders extends FromCatalystHelpers with ToCatalystHelpers with Seri val reads = Invoke( Literal.fromObject(cmp), "messageReads", - ObjectType(classOf[Reads[_]]), + ObjectType(classOf[Reads[?]]), Nil ) - val read = Invoke(reads, "read", ObjectType(classOf[Function[_, _]])) + val read = Invoke(reads, "read", ObjectType(classOf[Function[?, ?]])) val ret = Invoke(read, "apply", ObjectType(ct.runtimeClass), expr :: Nil) ret diff --git a/sparksql-scalapb/src/test/protobuf/all_types2.proto b/sparksql-scalapb/src/test/protobuf/all_types2.proto deleted file mode 100644 index 163e78e..0000000 --- a/sparksql-scalapb/src/test/protobuf/all_types2.proto +++ /dev/null @@ -1,172 +0,0 @@ -syntax = "proto2"; - -package scalapb.spark.test; - -import "google/protobuf/any.proto"; -import "google/protobuf/duration.proto"; -import "google/protobuf/timestamp.proto"; -import "scalapb/scalapb.proto"; - -option (scalapb.options) = { - preserve_unknown_fields: false -}; - -message Int32Test { - optional int32 optional_int32 = 1; - optional uint32 optional_uint32 = 2; - optional sint32 optional_sint32 = 3; - optional fixed32 optional_fixed32 = 4; - optional sfixed32 optional_sfixed32 = 5; - - required int32 required_int32 = 6; - required uint32 required_uint32 = 7; - required sint32 required_sint32 = 8; - required fixed32 required_fixed32 = 9; - required sfixed32 required_sfixed32 = 10; - - repeated int32 repeated_int32 = 11; - repeated uint32 repeated_uint32 = 12; - repeated sint32 repeated_sint32 = 13; - repeated fixed32 repeated_fixed32 = 14; - repeated sfixed32 repeated_sfixed32 = 15; -} - -message Int64Test { - optional int64 optional_int64 = 1; - optional uint64 optional_uint64 = 2; - optional sint64 optional_sint64 = 3; - optional fixed64 optional_fixed64 = 4; - optional sfixed64 optional_sfixed64 
= 5; - - required int64 required_int64 = 6; - required uint64 required_uint64 = 7; - required sint64 required_sint64 = 8; - required fixed64 required_fixed64 = 9; - required sfixed64 required_sfixed64 = 10; - - repeated int64 repeated_int64 = 11; - repeated uint64 repeated_uint64 = 12; - repeated sint64 repeated_sint64 = 13; - repeated fixed64 repeated_fixed64 = 14; - repeated sfixed64 repeated_sfixed64 = 15; -} - -message DoubleTest { - optional double optional_double = 1; - required double required_double = 2; - repeated double repeated_double = 3; -} - -message FloatTest { - optional float optional_float = 1; - required float required_float = 2; - repeated float repeated_float = 3; -} - -message StringTest { - optional string optional_string = 1; - required string required_string = 2; - repeated string repeated_string = 3; -} - -message BoolTest { - optional bool optional_bool = 1; - required bool required_bool = 2; - repeated bool repeated_bool = 3; -} - -message BytesTest { - optional bytes optional_bytes = 1; - required bytes required_bytes = 2; - repeated bytes repeated_bytes = 3; -} - -enum TopLevelEnum { - FOREIGN_FOO = 4; - FOREIGN_BAR = 5; - FOREIGN_BAZ = 6; -} - -enum TopLevelEnum0 { - DEFAULT = 0; - TOP_FOO = 4; - TOP_BAR = 5; - TOP_BAZ = 6; -} - -message EnumTest { - enum NestedEnum { - FOO = 1; - BAR = 2; - BAZ = 3; - NEG = -1; // Intentionally negative. - } - optional NestedEnum optional_nested_enum = 1; - required NestedEnum required_nested_enum = 2; - repeated NestedEnum repeated_nested_enum = 3; - optional TopLevelEnum optional_top_level_enum = 4; - required TopLevelEnum required_top_level_enum = 5; - repeated TopLevelEnum repeated_top_level_enum = 6; -} - -message TopLevelMessage { - optional int32 c = 1; -} - -message MessageTest { - message NestedMessage { - optional int32 bb = 1; - } - optional NestedMessage optional_nested_message = 1; - required NestedMessage required_nested_message = 2; - repeated NestedMessage repeated_nested_message = 3; - - optional TopLevelMessage optional_top_level_message = 4; - required TopLevelMessage required_top_level_message = 5; - repeated TopLevelMessage repeated_top_level_message = 6; -} - -message OneofTest { - oneof oneof_field { - uint32 oneof_uint32 = 111; - MessageTest.NestedMessage oneof_nested_message = 112; - string oneof_string = 113; - bytes oneof_bytes = 114; - } -} - -message Level1 { - message Level2 { - message Level3 { - optional bytes c = 2; - } - optional Level3 level3 = 1; - optional int32 b = 2; - } - optional Level2 level2 = 1; - optional string a = 2; -} - -message AnyTest { - optional google.protobuf.Any any = 1; -} - -message WellKnownTypes { - optional google.protobuf.Timestamp timestamp = 1; - optional google.protobuf.Duration duration = 2; - repeated google.protobuf.Timestamp timestamps = 3; - repeated google.protobuf.Duration durations = 4; -} - -message MapTypes { - map int32_to_bool = 1; - map int32_to_bytes = 2; - map int32_to_double = 3; - map int32_to_float = 4; - map int32_to_int32 = 5; - map int32_to_int64 = 6; - map int32_to_string = 7; - map string_to_int32 = 8; - map int32_to_message = 9; - map int32_to_enum = 10; -} \ No newline at end of file diff --git a/sparksql-scalapb/src/test/protobuf/all_types3.proto b/sparksql-scalapb/src/test/protobuf/all_types3.proto deleted file mode 100644 index 6b4d8ed..0000000 --- a/sparksql-scalapb/src/test/protobuf/all_types3.proto +++ /dev/null @@ -1,169 +0,0 @@ -syntax = "proto3"; - -package scalapb.spark.test3; - -import "google/protobuf/any.proto"; -import 
"google/protobuf/wrappers.proto"; -import "google/protobuf/duration.proto"; -import "google/protobuf/timestamp.proto"; -import "scalapb/scalapb.proto"; - -option (scalapb.options) = { - preserve_unknown_fields: false -}; - -message Int32Test { - int32 optional_int32 = 1; - uint32 optional_uint32 = 2; - sint32 optional_sint32 = 3; - fixed32 optional_fixed32 = 4; - sfixed32 optional_sfixed32 = 5; - - repeated int32 repeated_int32 = 11; - repeated uint32 repeated_uint32 = 12; - repeated sint32 repeated_sint32 = 13; - repeated fixed32 repeated_fixed32 = 14; - repeated sfixed32 repeated_sfixed32 = 15; -} - -message Int64Test { - int64 optional_int64 = 1; - uint64 optional_uint64 = 2; - sint64 optional_sint64 = 3; - fixed64 optional_fixed64 = 4; - sfixed64 optional_sfixed64 = 5; - - repeated int64 repeated_int64 = 11; - repeated uint64 repeated_uint64 = 12; - repeated sint64 repeated_sint64 = 13; - repeated fixed64 repeated_fixed64 = 14; - repeated sfixed64 repeated_sfixed64 = 15; -} - -message DoubleTest { - double optional_double = 1; - repeated double repeated_double = 3; -} - -message FloatTest { - float optional_float = 1; - repeated float repeated_float = 3; -} - -message StringTest { - string optional_string = 1; - repeated string repeated_string = 3; -} - -message BoolTest { - bool optional_bool = 1; - repeated bool repeated_bool = 3; -} - -message BytesTest { - bytes optional_bytes = 1; - repeated bytes repeated_bytes = 3; -} - -enum TopLevelEnum { - EMPTY = 0; - FOREIGN_FOO = 4; - FOREIGN_BAR = 5; - FOREIGN_BAZ = 6; -} - -message EnumTest { - enum NestedEnum { - UNKNOWN = 0; - FOO = 1; - BAR = 2; - BAZ = 3; - NEG = -1; // Intentionally negative. - } - NestedEnum optional_nested_enum = 1; - repeated NestedEnum repeated_nested_enum = 3; - TopLevelEnum optional_top_level_enum = 4; - repeated TopLevelEnum repeated_top_level_enum = 6; -} - -message TopLevelMessage { - int32 c = 1; -} - -message MessageTest { - message NestedMessage { - int32 bb = 1; - } - NestedMessage optional_nested_message = 1; - repeated NestedMessage repeated_nested_message = 3; - - TopLevelMessage optional_top_level_message = 4; - repeated TopLevelMessage repeated_top_level_message = 6; -} - -message OneofTest { - oneof oneof_field { - uint32 oneof_uint32 = 111; - MessageTest.NestedMessage oneof_nested_message = 112; - string oneof_string = 113; - bytes oneof_bytes = 114; - } -} - -message Level1 { - message Level2 { - message Level3 { - bytes c = 2; - } - Level3 level3 = 1; - int32 b = 2; - } - Level2 level2 = 1; - string a = 2; -} - -message AnyTest { - google.protobuf.Any any = 1; -} - -message WrappersTest { - google.protobuf.BoolValue bool_value = 1; - google.protobuf.BytesValue bytes_value = 2; - google.protobuf.DoubleValue double_value = 3; - google.protobuf.FloatValue float_value = 4; - google.protobuf.Int32Value int32_value = 5; - google.protobuf.Int64Value int64_value = 6; - google.protobuf.StringValue string_value = 7; - google.protobuf.UInt32Value uint32_value = 8; - google.protobuf.UInt64Value uint64_Value = 9; - - repeated google.protobuf.BoolValue bool_repeated_Value = 11; - repeated google.protobuf.BytesValue bytes_repeated_Value = 12; - repeated google.protobuf.DoubleValue double_repeated_Value = 13; - repeated google.protobuf.FloatValue float_repeated_Value = 14; - repeated google.protobuf.Int32Value int32_repeated_Value = 15; - repeated google.protobuf.Int64Value int64_repeated_Value = 16; - repeated google.protobuf.StringValue string_repeated_Value = 17; - repeated google.protobuf.UInt32Value 
uint32_repeated_Value = 18; - repeated google.protobuf.UInt64Value uint64_repeated_Value = 19; -} - -message WellKnownTypes { - google.protobuf.Timestamp timestamp = 1; - google.protobuf.Duration duration = 2; - repeated google.protobuf.Timestamp timestamps = 3; - repeated google.protobuf.Duration durations = 4; -} - -message MapTypes { - map int32_to_bool = 1; - map int32_to_bytes = 2; - map int32_to_double = 3; - map int32_to_float = 4; - map int32_to_int32 = 5; - map int32_to_int64 = 6; - map int32_to_string = 7; - map string_to_int32 = 8; - map int32_to_message = 9; - map int32_to_enum = 10; -} \ No newline at end of file diff --git a/sparksql-scalapb/src/test/protobuf/base.proto b/sparksql-scalapb/src/test/protobuf/base.proto deleted file mode 100644 index 055f344..0000000 --- a/sparksql-scalapb/src/test/protobuf/base.proto +++ /dev/null @@ -1,10 +0,0 @@ -syntax = "proto2"; - -option java_package = "com.example.protos"; - -message Base { - enum Color { - BLUE = 1; - RED = 2; - } -} diff --git a/sparksql-scalapb/src/test/protobuf/customizations.proto b/sparksql-scalapb/src/test/protobuf/customizations.proto deleted file mode 100644 index f5911f0..0000000 --- a/sparksql-scalapb/src/test/protobuf/customizations.proto +++ /dev/null @@ -1,28 +0,0 @@ -syntax = "proto3"; - -package scalapb.spark.test3; - -import "google/protobuf/timestamp.proto"; -import "scalapb/scalapb.proto"; - -message StructFromGoogleTimestamp { - google.protobuf.Timestamp google_ts = 1; -} - -// Needed to be able to represent the case class property as java.sql.Timestamp -option (scalapb.options) = { - import: "scalapb.spark.TypeMappers._" -}; - -message SQLTimestampFromGoogleTimestamp { - google.protobuf.Timestamp google_ts_as_sql_ts = 1 [(scalapb.field).type = "java.sql.Timestamp"]; -} - -message BothTimestampTypes { - google.protobuf.Timestamp google_ts = 1; - google.protobuf.Timestamp google_ts_as_sql_ts = 2 [(scalapb.field).type = "java.sql.Timestamp"]; -} - -message TimestampTypesMap { - map map_field = 1; -} diff --git a/sparksql-scalapb/src/test/protobuf/defaults.proto b/sparksql-scalapb/src/test/protobuf/defaults.proto deleted file mode 100644 index 6374526..0000000 --- a/sparksql-scalapb/src/test/protobuf/defaults.proto +++ /dev/null @@ -1,27 +0,0 @@ -syntax = "proto2"; - -option java_package = "com.example.protos"; - -message DefaultsRequired { - required int32 i32Value = 1; - required int64 i64Value = 2; - required uint32 u32Value = 3; - required uint64 u64Value = 4; - required double dValue = 5; - required float fValue = 6; - required bool bValue = 7; - required string sValue = 8; - required bytes binaryValue = 9; -} - -message DefaultsOptional { - optional int32 i32Value = 1; - optional int64 i64Value = 2; - optional uint32 u32Value = 3; - optional uint64 u64Value = 4; - optional double dValue = 5; - optional float fValue = 6; - optional bool bValue = 7; - optional string sValue = 8; - optional bytes binaryValue = 9; -} diff --git a/sparksql-scalapb/src/test/protobuf/defaultsv3.proto b/sparksql-scalapb/src/test/protobuf/defaultsv3.proto deleted file mode 100644 index 49c8a9f..0000000 --- a/sparksql-scalapb/src/test/protobuf/defaultsv3.proto +++ /dev/null @@ -1,15 +0,0 @@ -syntax = "proto3"; - -option java_package = "com.example.protos"; - -message DefaultsV3 { - int32 i32Value = 1; - int64 i64Value = 2; - uint32 u32Value = 3; - uint64 u64Value = 4; - double dValue = 5; - float fValue = 6; - bool bValue = 7; - string sValue = 8; - bytes binaryValue = 9; -} diff --git 
a/sparksql-scalapb/src/test/protobuf/demo.proto b/sparksql-scalapb/src/test/protobuf/demo.proto deleted file mode 100644 index 9c1f702..0000000 --- a/sparksql-scalapb/src/test/protobuf/demo.proto +++ /dev/null @@ -1,54 +0,0 @@ -syntax = "proto2"; - -option java_package = "com.example.protos"; - -import "base.proto"; -import "google/protobuf/any.proto"; - -enum Gender { - MALE = 1; - FEMALE = 2; -} - -message Address { - optional string street = 1; - optional string city = 2; -} - -message SimplePerson { - optional string name = 1; - optional int32 age = 2; - repeated string tags = 3; - optional Address address = 4; - repeated int32 nums = 5; -} - -message Person { - optional string name = 1; - optional int32 age = 2; - optional Gender gender = 3; - repeated string tags = 4; - repeated Address addresses = 5; - optional Base base = 6; - - message Inner { - enum InnerEnum { - V0 = 0; - V1 = 1; - } - - optional InnerEnum inner_value = 1; - } - optional Inner inner = 7; - optional bytes data = 8; -} - -message Event { - optional string eventId = 1; - optional google.protobuf.Any action = 2; -} - -message Hit { - optional bytes id = 1; - optional string target = 2; -} diff --git a/sparksql-scalapb/src/test/protobuf/maps.proto b/sparksql-scalapb/src/test/protobuf/maps.proto deleted file mode 100644 index c531e41..0000000 --- a/sparksql-scalapb/src/test/protobuf/maps.proto +++ /dev/null @@ -1,7 +0,0 @@ -syntax = "proto2"; - -option java_package = "com.example.protos"; - -message MapTest { - map attributes = 16; -} \ No newline at end of file diff --git a/sparksql-scalapb/src/test/protobuf/schema_bug.proto b/sparksql-scalapb/src/test/protobuf/schema_bug.proto deleted file mode 100644 index e6f67c2..0000000 --- a/sparksql-scalapb/src/test/protobuf/schema_bug.proto +++ /dev/null @@ -1,35 +0,0 @@ -syntax = "proto3"; - -package scalapb.spark.schema.bug; - -message Write { - repeated RepeatedNestedWrite nested_field = 6; - string additional_field = 14; -} - -message RepeatedNestedWrite { - RepeatedOmitWrite omit_for_read = 2; - string field_one = 7; - string field_two = 8; - Nested persisted = 9; -} - -message RepeatedOmitWrite { - string field_one = 5; - string field_two = 6; -} - -message Nested { - string field_one = 5; -} - -message Read { - repeated RepeatedNestedRead nested_field = 6; - string additional_field = 14; -} - -message RepeatedNestedRead { - string field_one = 7; - string field_two = 8; - Nested persisted = 9; -} diff --git a/sparksql-scalapb/src/test/protobuf/wrappers.proto b/sparksql-scalapb/src/test/protobuf/wrappers.proto deleted file mode 100644 index a0620c9..0000000 --- a/sparksql-scalapb/src/test/protobuf/wrappers.proto +++ /dev/null @@ -1,13 +0,0 @@ -syntax = "proto3"; - -option java_package = "com.example.protos"; - -import "google/protobuf/wrappers.proto"; - -message PrimitiveWrappers { - google.protobuf.Int32Value int_value = 1; - google.protobuf.StringValue string_value = 2; - - repeated google.protobuf.Int32Value ints = 3; - repeated google.protobuf.StringValue strings = 4; -} \ No newline at end of file diff --git a/sparksql-scalapb/src/test/resources/address.json b/sparksql-scalapb/src/test/resources/address.json deleted file mode 100644 index 6d95903..0000000 --- a/sparksql-scalapb/src/test/resources/address.json +++ /dev/null @@ -1 +0,0 @@ -{"foo": "bar"} diff --git a/sparksql-scalapb/src/test/resources/person_null_repeated.json b/sparksql-scalapb/src/test/resources/person_null_repeated.json deleted file mode 100644 index 710c3a5..0000000 --- 
a/sparksql-scalapb/src/test/resources/person_null_repeated.json +++ /dev/null @@ -1,4 +0,0 @@ -{"tags": ["foo", "bar"]} -{"tags": []} -{"tags": null} -{} diff --git a/sparksql-scalapb/src/test/scala-2.12/scalapb/spark/AllTypesSpec.scala b/sparksql-scalapb/src/test/scala-2.12/scalapb/spark/AllTypesSpec.scala deleted file mode 100644 index 17d7fba..0000000 --- a/sparksql-scalapb/src/test/scala-2.12/scalapb/spark/AllTypesSpec.scala +++ /dev/null @@ -1,128 +0,0 @@ -package scalapb.spark - -import org.apache.spark.sql.{Dataset, Encoder, SparkSession} -import org.scalacheck.Arbitrary -import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks -import org.scalatest.BeforeAndAfterAll -import scalapb.spark.test.{all_types2 => AT2} -import scalapb.spark.test3.{all_types3 => AT3} -import scalapb.{GeneratedMessage, GeneratedMessageCompanion, Message} -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.matchers.must.Matchers -import scala.reflect.ClassTag - -class AllTypesSpec - extends AnyFlatSpec - with Matchers - with BeforeAndAfterAll - with ScalaCheckDrivenPropertyChecks { - val spark: SparkSession = SparkSession - .builder() - .appName("ScalaPB Demo") - .master("local[2]") - .getOrCreate() - - import ArbitraryProtoUtils._ - import org.scalacheck.ScalacheckShapeless._ - import spark.implicits.{newProductEncoder => _} - - def verifyTypes[ - T <: GeneratedMessage: Arbitrary: GeneratedMessageCompanion: ClassTag - ]( - protoSQL: ProtoSQL - ): Unit = - forAll { (n: Seq[T]) => - import protoSQL.implicits._ - - // ProtoSQL conversion to dataframe - val df1 = protoSQL.createDataFrame(spark, n) - - // Creates dataset using encoder deserialization: - val ds1: Dataset[T] = df1.as[T] - ds1.collect() must contain theSameElementsAs (n) - // Creates dataframe using encoder serialization: - val ds2 = spark.createDataset(n) - ds2.collect() must contain theSameElementsAs (n) - } - - def verifyTypes[ - T <: GeneratedMessage: Arbitrary: GeneratedMessageCompanion: ClassTag - ]: Unit = - verifyTypes[T](ProtoSQL) - - "AllTypes" should "work for int32" in { - verifyTypes[AT2.Int32Test] - verifyTypes[AT3.Int32Test] - } - - it should "work for int64" in { - verifyTypes[AT2.Int64Test] - verifyTypes[AT3.Int64Test] - } - - it should "work for bools" in { - verifyTypes[AT2.BoolTest] - verifyTypes[AT3.BoolTest] - } - - it should "work for strings" in { - verifyTypes[AT2.StringTest] - verifyTypes[AT3.StringTest] - } - - it should "work for floats" in { - verifyTypes[AT2.FloatTest] - verifyTypes[AT3.FloatTest] - } - - it should "work for doubles" in { - verifyTypes[AT2.DoubleTest] - verifyTypes[AT3.DoubleTest] - } - - it should "work for bytes" in { - verifyTypes[AT2.BytesTest] - verifyTypes[AT3.BytesTest] - } - - it should "work for enums" in { - verifyTypes[AT2.EnumTest] - verifyTypes[AT3.EnumTest] - } - - it should "work for messages" in { - verifyTypes[AT2.MessageTest] - verifyTypes[AT3.MessageTest] - } - - it should "work for oneofs" in { - verifyTypes[AT2.OneofTest] - verifyTypes[AT3.OneofTest] - } - - it should "work for levels" in { - verifyTypes[AT2.Level1] - verifyTypes[AT3.Level1] - } - - it should "work for any" in { - verifyTypes[AT2.AnyTest] - verifyTypes[AT3.AnyTest] - } - - it should "work for time types" in { - verifyTypes[AT2.WellKnownTypes] - verifyTypes[AT3.WellKnownTypes] - } - - it should "work for wrapper types" in { - verifyTypes[AT3.WrappersTest] - verifyTypes[AT3.WrappersTest](ProtoSQL.withRetainedPrimitiveWrappers) - verifyTypes[AT3.WrappersTest](new 
ProtoSQL(SchemaOptions().withScalaNames)) - } - - it should "work for maps" in { - verifyTypes[AT2.MapTypes](ProtoSQL) - verifyTypes[AT3.MapTypes](ProtoSQL) - } -} diff --git a/sparksql-scalapb/src/test/scala-2.12/scalapb/spark/ArbitraryProtoUtils.scala b/sparksql-scalapb/src/test/scala-2.12/scalapb/spark/ArbitraryProtoUtils.scala deleted file mode 100644 index e0ac4bb..0000000 --- a/sparksql-scalapb/src/test/scala-2.12/scalapb/spark/ArbitraryProtoUtils.scala +++ /dev/null @@ -1,52 +0,0 @@ -package scalapb.spark - -import com.google.protobuf.ByteString -import org.scalacheck.Arbitrary -import org.scalacheck.derive.MkArbitrary -import scalapb.spark.test.{all_types2 => AT2} -import scalapb.spark.test3.{all_types3 => AT3} -import scalapb.{GeneratedEnum, GeneratedEnumCompanion, GeneratedMessage, Message} -import shapeless.Strict -import org.scalacheck.Gen -import scalapb.UnknownFieldSet - -object ArbitraryProtoUtils { - import org.scalacheck.ScalacheckShapeless._ - - implicit val arbitraryBS = Arbitrary( - implicitly[Arbitrary[Array[Byte]]].arbitrary - .map(t => ByteString.copyFrom(t)) - ) - - // Default scalacheck-shapeless would chose Unrecognized instances with recognized values. - private def fixEnum[A <: GeneratedEnum]( - e: A - )(implicit cmp: GeneratedEnumCompanion[A]): A = { - if (e.isUnrecognized) cmp.values.find(_.value == e.value).getOrElse(e) - else e - } - - def arbitraryEnum[A <: GeneratedEnum: Arbitrary: GeneratedEnumCompanion] = { - Arbitrary(implicitly[Arbitrary[A]].arbitrary.map(fixEnum(_))) - } - - implicit val arbitraryUnknownFields = Arbitrary( - Gen.const(UnknownFieldSet.empty) - ) - - implicit val nestedEnum2 = arbitraryEnum[AT2.EnumTest.NestedEnum] - - implicit val nestedEnum3 = arbitraryEnum[AT3.EnumTest.NestedEnum] - - implicit val topLevelEnum2 = arbitraryEnum[AT2.TopLevelEnum] - - implicit val topLevelEnum0 = arbitraryEnum[AT2.TopLevelEnum0] - - implicit val topLevelEnum3 = arbitraryEnum[AT3.TopLevelEnum] - - implicit def arbitraryMessage[A <: GeneratedMessage](implicit - ev: Strict[MkArbitrary[A]] - ) = { - implicitly[Arbitrary[A]] - } -} diff --git a/sparksql-scalapb/src/test/scala/DefaultsSpec.scala b/sparksql-scalapb/src/test/scala/DefaultsSpec.scala deleted file mode 100644 index 0c47137..0000000 --- a/sparksql-scalapb/src/test/scala/DefaultsSpec.scala +++ /dev/null @@ -1,58 +0,0 @@ -package scalapb.spark - -import org.apache.spark.sql.{Row, SparkSession} -import org.scalatest.BeforeAndAfterAll -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.matchers.must.Matchers - -class DefaultsSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll { - val spark: SparkSession = SparkSession - .builder() - .appName("ScalaPB Demo") - .master("local[2]") - .getOrCreate() - - "Proto2 RDD[DefaultsRequired]" should "have non-null default values after converting to Dataframe" in { - import com.example.protos.defaults.DefaultsRequired - val defaults = DefaultsRequired.defaultInstance - val row = ProtoSQL.createDataFrame(spark, Seq(defaults)).collect().head - val expected = Row( - defaults.i32Value, - defaults.i64Value, - defaults.u32Value, - defaults.u64Value, - defaults.dValue, - defaults.fValue, - defaults.bValue, - defaults.sValue, - defaults.binaryValue.toByteArray - ) - row must be(expected) - } - - "Proto2 RDD[DefaultsOptional]" should "have null values after converting to Dataframe" in { - import com.example.protos.defaults.DefaultsOptional - val defaults = DefaultsOptional.defaultInstance - val row = ProtoSQL.createDataFrame(spark, 
Seq(defaults)).collect().head - val expected = Row(null, null, null, null, null, null, null, null, null) - row must be(expected) - } - - "Proto3 RDD[DefaultsV3]" should "have non-null default values after converting to Dataframe" in { - import com.example.protos.defaultsv3.DefaultsV3 - val defaults = DefaultsV3.defaultInstance - val row = ProtoSQL.createDataFrame(spark, Seq(defaults)).collect().head - val expected = Row( - defaults.i32Value, - defaults.i64Value, - defaults.u32Value, - defaults.u64Value, - defaults.dValue, - defaults.fValue, - defaults.bValue, - defaults.sValue, - defaults.binaryValue.toByteArray - ) - row must be(expected) - } -} diff --git a/sparksql-scalapb/src/test/scala/MapsSpec.scala b/sparksql-scalapb/src/test/scala/MapsSpec.scala deleted file mode 100644 index cbef47d..0000000 --- a/sparksql-scalapb/src/test/scala/MapsSpec.scala +++ /dev/null @@ -1,29 +0,0 @@ -package scalapb.spark - -import com.example.protos.maps._ -import org.scalatest.matchers.should.Matchers -import org.scalatest.flatspec.AnyFlatSpec -import org.apache.spark.sql.SparkSession -import org.scalatest.BeforeAndAfterAll - -class MapsSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll { - val spark: SparkSession = SparkSession - .builder() - .appName("ScalaPB Demo") - .master("local[2]") - .getOrCreate() - - import spark.implicits.StringToColumn - import Implicits._ - - val data = Seq( - MapTest(attributes = Map("foo" -> "bar")) - ) - - "converting maps to df" should "work" in { - val df = ProtoSQL.createDataFrame(spark, data) - val res = df.as[MapTest].map(r => r) - - res.show() - } -} diff --git a/sparksql-scalapb/src/test/scala/PersonSpec.scala b/sparksql-scalapb/src/test/scala/PersonSpec.scala deleted file mode 100644 index 96c021c..0000000 --- a/sparksql-scalapb/src/test/scala/PersonSpec.scala +++ /dev/null @@ -1,361 +0,0 @@ -package scalapb.spark - -import com.example.protos.demo.Person.Inner.InnerEnum -import com.example.protos.demo.{Address, Event, Gender, Hit, Person, SimplePerson} -import com.google.protobuf.ByteString -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession, types, functions => F} -import org.scalatest.BeforeAndAfterAll -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.matchers.must.Matchers - -import java.sql.Timestamp - -case class InnerLike(inner_value: String) - -case class AddressLike(street: Option[String], city: Option[String]) - -case class BaseLike() - -case class PersonLike( - name: String, - age: Int, - addresses: Seq[AddressLike], - gender: String, - tags: Seq[String] = Seq.empty, - base: Option[BaseLike] = None, - inner: Option[InnerLike] = None, - data: Option[Array[Byte]] = None, - address: Option[AddressLike] = None, - nums: Vector[Int] = Vector.empty -) - -case class OuterCaseClass(x: Person, y: String) -case class OuterCaseClassTimestamp(x: Person, y: Timestamp) - -class PersonSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll { - val spark: SparkSession = SparkSession - .builder() - .appName("ScalaPB Demo") - .master("local[2]") - .getOrCreate() - - import spark.implicits.StringToColumn - import Implicits._ - - val TestPerson = Person().update( - _.name := "Owen M", - _.age := 35, - _.gender := Gender.MALE, - _.addresses := Seq( - Address().update( - _.city := "San Francisco" - ) - ), - _.data := ByteString.copyFrom(Array[Byte](1, 2, 3)) - ) - - "mapping datasets" should "work" in { - val s = ( - 14, - SimplePerson( - name = Some("foo"), - address = Some(Address(street = Some("St"), city = Some("Ct"))) - ), - 17 - 
) - val ds1 = spark.createDataset(Seq(s)).map(_._2) - ds1.collect() must contain theSameElementsAs (Seq(s._2)) - ds1.map(_.getAddress).collect() must contain theSameElementsAs (Seq( - s._2.getAddress - )) - spark - .createDataset(Seq(s)) - .toDF() - .select($"_2.address.*") - .as[Address] - .collect() must contain theSameElementsAs (Seq(s._2.getAddress)) - } - - "Creating person dataset" should "work" in { - val s = Seq(Person().withName("Foo"), Person().withName("Bar")) - - val ds = spark.sqlContext.createDataset(s) - ds.count() must be(2) - } - - "Creating enum dataset" should "work" in { - val gendersStr = Seq((1, "MALE"), (2, "MALE"), (3, "FEMALE"), (5, "15")) - val gendersObj = Seq( - (1, Gender.MALE), - (2, Gender.MALE), - (3, Gender.FEMALE), - (5, Gender.Unrecognized(15)) - ) - - spark - .createDataset(gendersStr) - .as[(Int, Gender)] - .collect() - .toVector must contain theSameElementsAs (gendersObj) - spark - .createDataset(gendersObj) - .as[(Int, String)] - .collect() must contain theSameElementsAs (gendersStr) - } - - "Creating bytestring dataset" should "work" in { - val byteStrings: Seq[ByteString] = - Seq(ByteString.copyFrom(Array[Byte](1, 2, 3)), ByteString.EMPTY) - val bytesArrays = byteStrings.map(_.toByteArray) - - spark - .createDataset(byteStrings) - .as[Array[Byte]] - .collect() - .toVector must contain theSameElementsAs (bytesArrays) - spark - .createDataset(bytesArrays) - .as[ByteString] - .collect() must contain theSameElementsAs (byteStrings) - - spark - .createDataset(byteStrings) - .map(bs => (bs.toString, bs)) - .show() - } - - "Dataset[Person]" should "work" in { - val ds: Dataset[Person] = spark.createDataset(Seq(TestPerson)) - ds.where($"age" > 30).count() must be(1) - ds.where($"age" > 40).count() must be(0) - ds.where($"gender" === "MALE").count() must be(1) - ds.collect() must be(Array(TestPerson)) - ds.toDF().as[Person].collect() must be(Array(TestPerson)) - ds.select("data").printSchema() - ds.select(F.sha1(F.col("data"))).printSchema() - ds.show() - ds.toDF().printSchema() - } - - "as[SimplePerson]" should "work for manual building" in { - val pl = PersonLike( - name = "Owen M", - age = 35, - addresses = Seq.empty, - gender = "MALE", - inner = Some(InnerLike("V1")), - tags = Seq("foo", "bar"), - address = Some(AddressLike(Some("Main"), Some("Bar"))), - nums = Vector(3, 4, 5) - ) - val p = - SimplePerson().update( - _.name := "Owen M", - _.age := 35, - _.tags := Seq("foo", "bar"), - _.address.street := "Main", - _.address.city := "Bar", - _.nums := Seq(3, 4, 5) - ) - val manualDF: DataFrame = spark.createDataFrame(Seq(pl)) - val manualDS: Dataset[SimplePerson] = spark.createDataset(Seq(p)) - manualDF.as[SimplePerson].collect()(0) must be(p) - manualDS.collect()(0) must be(p) - } - - "as[Person]" should "work for manual building" in { - val pl = PersonLike( - name = "Owen M", - age = 35, - addresses = Seq( - AddressLike(Some("foo"), Some("bar")), - AddressLike(Some("baz"), Some("taz")) - ), - gender = "MALE", - inner = Some(InnerLike("V1")), - data = Some(TestPerson.getData.toByteArray) - ) - val manualDF: DataFrame = spark.createDataFrame(Seq(pl)) - manualDF.show() - manualDF.as[Person].collect()(0) must be( - Person().update( - _.name := "Owen M", - _.age := 35, - _.gender := Gender.MALE, - _.inner.innerValue := InnerEnum.V1, - _.data := ByteString.copyFrom(Array[Byte](1, 2, 3)), - _.addresses := pl.addresses.map(a => Address(city = a.city, street = a.street)) - ) - ) - spark.createDataset(Seq(Person(gender = Some(Gender.FEMALE)))).toDF().show() - } - 
- "converting from rdd to dataframe" should "work" in { - val rdd = spark.sparkContext.parallelize(Seq(Person(name = Some("foo")))) - rdd - .toDataFrame(spark) - .select($"name") - .collect() - .map(_.getAs[String]("name")) must contain theSameElementsAs (Vector( - "foo" - )) - } - - "selecting message fields into dataset should work" should "work" in { - val df = ProtoSQL.createDataFrame( - spark, - Seq( - TestPerson, - TestPerson.withName("Other").clearAddresses, - TestPerson - .withName("Other2") - .clearData - .clearGender - .clearAddresses - .addAddresses(Address(street = Some("FooBar"))) - ) - ) - - val ds = df.select($"name", $"addresses".getItem(0)) - - ds.as[(String, Option[Address])].collect() must contain theSameElementsAs ( - Seq( - (TestPerson.getName, Some(TestPerson.addresses.head)), - ("Other", None), - ("Other2", Some(Address(street = Some("FooBar")))) - ) - ) - - ds.as[(String, Address)].collect() must contain theSameElementsAs ( - Seq( - (TestPerson.getName, TestPerson.addresses.head), - null, - ("Other2", Address(street = Some("FooBar"))) - ) - ) - - val ds2 = df.select($"name", $"gender") - ds2.as[(String, Option[Gender])].collect() must contain theSameElementsAs ( - Seq( - (TestPerson.getName, Some(Gender.MALE)), - ("Other", Some(Gender.MALE)), - ("Other2", None) - ) - ) - ds2.as[(String, Gender)].collect() must contain theSameElementsAs ( - Seq( - (TestPerson.getName, Gender.MALE), - ("Other", Gender.MALE), - null - ) - ) - - val ds3 = df.select($"name", $"data") - ds3 - .as[(String, Option[ByteString])] - .collect() must contain theSameElementsAs ( - Seq( - (TestPerson.getName, Some(TestPerson.getData)), - ("Other", Some(TestPerson.getData)), - ("Other2", None) - ) - ) - } - - "serialize and deserialize" should "work on dataset of bytes" in { - val s = Seq( - TestPerson.update(_.name := "p1"), - TestPerson.update(_.name := "p2"), - TestPerson.update(_.name := "p3") - ) - val bs: Dataset[Array[Byte]] = spark.createDataset(s).map(_.toByteArray) - bs.map(Person.parseFrom).collect() must contain theSameElementsAs (s) - } - - "UDFs that involve protos" should "work when using ProtoSQL.udfs" in { - val h1 = Hit( - id = Some(ByteString.copyFrom(Array[Byte](112, 75, 6))), - target = Some("foo") - ) - val events: Seq[Event] = - Seq( - Event( - eventId = Some("xyz"), - action = Some(com.google.protobuf.any.Any.pack(h1)) - ) - ) - val df = ProtoSQL.createDataFrame(spark, events) - val parseHit = ProtoSQL.udf { s: Array[Byte] => Hit.parseFrom(s) } - df.withColumn("foo", parseHit($"action.value")).show() - } - - "UDFs that returns protos" should "work when using ProtoSQL.createDataFrame" in { - val h1 = Hit( - id = Some(ByteString.copyFrom(Array[Byte](112, 75, 6))), - target = Some("foo") - ) - - val events: Seq[Address] = - Seq( - Address() - ) - val df = ProtoSQL.createDataFrame(spark, events) - - val returnAddress = ProtoSQL.udf { s: String => Address() } - - df.withColumn("address", returnAddress($"street")) - .write - .mode("overwrite") - .save("/tmp/address1") - } - - "UDFs that returns protos" should "work when reading local files" in { - val df = spark.read.json(getClass.getResource("/address.json").toURI.toString) - - val returnAddress = ProtoSQL.udf { s: String => Address() } - - df.withColumn("address", returnAddress($"foo")) - .write - .mode("overwrite") - .save("/tmp/address2") - } - - "OuterCaseClass" should "use our type encoders" in { - val outer = OuterCaseClass(TestPerson, "foo") - val df = spark.createDataset(Seq(outer)).toDF() - 
df.select($"x.*").as[Person].collect() must contain theSameElementsAs (Seq(TestPerson)) - } - - "OuterCaseClassTimestamp" should "serialize a java.sql.Timestamp" in { - implicit val timestampInjection = new frameless.Injection[Timestamp, frameless.SQLTimestamp] { - def apply(ts: Timestamp): frameless.SQLTimestamp = { - val i = ts.toInstant() - frameless.SQLTimestamp(i.getEpochSecond() * 1000000 + i.getNano() / 1000) - } - - def invert(l: frameless.SQLTimestamp): Timestamp = Timestamp.from( - java.time.Instant.EPOCH.plus(l.us, java.time.temporal.ChronoUnit.MICROS) - ) - } - - val ts = Timestamp.valueOf("2020-11-17 21:34:56.157") - val outer = OuterCaseClassTimestamp(TestPerson, ts) - val df = spark.createDataset(Seq(outer)).toDF() - - df.select($"x.*").as[Person].collect() must contain theSameElementsAs (Seq(TestPerson)) - df.select($"y").as[Timestamp].collect() must contain theSameElementsAs Seq(ts) - } - - "parsing null repeated from json" should "work" in { - spark.read - .schema(ProtoSQL.schemaFor[Person].asInstanceOf[types.StructType]) - .json(getClass.getResource("/person_null_repeated.json").toURI.toString) - .as[Person] - .collect() must contain theSameElementsAs Seq( - Person().withTags(Seq("foo", "bar")), - Person(), - Person(), - Person() - ) - } -} diff --git a/sparksql-scalapb/src/test/scala/RepeatedSchemaSpec.scala b/sparksql-scalapb/src/test/scala/RepeatedSchemaSpec.scala deleted file mode 100644 index b458e96..0000000 --- a/sparksql-scalapb/src/test/scala/RepeatedSchemaSpec.scala +++ /dev/null @@ -1,154 +0,0 @@ -package scalapb.spark - -import org.apache.spark.sql.SparkSession -import org.scalatest.BeforeAndAfterAll -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.matchers.must.Matchers -import scalapb.spark.schema.bug.schema_bug.{ - Nested, - Read, - RepeatedNestedRead, - RepeatedNestedWrite, - RepeatedOmitWrite, - Write -} - -// See https://github.com/scalapb/sparksql-scalapb/issues/313 -class RepeatedSchemaSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll { - val spark: SparkSession = SparkSession - .builder() - .appName("ScalaPB Demo") - .master("local[2]") - .config("spark.ui.enabled", "false") - .getOrCreate() - - "Read from data with extra columns" should "work" in { - val data = Seq( - Write( - nestedField = Seq( - RepeatedNestedWrite( - omitForRead = Some( - RepeatedOmitWrite( - fieldOne = "field_one_repeated_11", - fieldTwo = "field_two_repeated_11" - ) - ), - fieldOne = "field_one_11", - fieldTwo = "field_two_11", - persisted = Some(Nested(fieldOne = "field_one_nested_11")) - ), - RepeatedNestedWrite( - omitForRead = Some( - RepeatedOmitWrite( - fieldOne = "field_one_repeated_21", - fieldTwo = "field_two_repeated_21" - ) - ), - fieldOne = "field_one_21", - fieldTwo = "field_two_21", - persisted = Some(Nested(fieldOne = "field_one_nested_21")) - ), - RepeatedNestedWrite( - omitForRead = Some( - RepeatedOmitWrite( - fieldOne = "field_one_repeated_31", - fieldTwo = "field_two_repeated_31" - ) - ), - fieldOne = "field_one_31", - fieldTwo = "field_two_31", - persisted = Some(Nested(fieldOne = "field_one_nested_31")) - ) - ), - additionalField = "additional_1" - ), - Write( - nestedField = Seq( - RepeatedNestedWrite( - omitForRead = Some( - RepeatedOmitWrite( - fieldOne = "field_one_repeated_12", - fieldTwo = "field_two_repeated_12" - ) - ), - fieldOne = "field_one_12", - fieldTwo = "field_two_12", - persisted = Some(Nested(fieldOne = "field_one_nested_12")) - ), - RepeatedNestedWrite( - omitForRead = Some( - RepeatedOmitWrite( - fieldOne = 
"field_one_repeated_22", - fieldTwo = "field_two_repeated_22" - ) - ), - fieldOne = "field_one_22", - fieldTwo = "field_two_22", - persisted = Some(Nested(fieldOne = "field_one_nested_22")) - ), - RepeatedNestedWrite( - omitForRead = Some( - RepeatedOmitWrite( - fieldOne = "field_one_repeated_32", - fieldTwo = "field_two_repeated_32" - ) - ), - fieldOne = "field_one_32", - fieldTwo = "field_two_32", - persisted = Some(Nested(fieldOne = "field_one_nested_32")) - ) - ), - additionalField = "additional_2" - ) - ) - - import ProtoSQL.implicits._ - val path = "/tmp/repeated-nested-bug" - spark.createDataset(data).write.mode("overwrite").parquet(path) - val readDf = spark.read.parquet(path) - val readDs = readDf.as[Read] - val readExpected = Seq( - Read( - nestedField = Seq( - RepeatedNestedRead( - fieldOne = "field_one_11", - fieldTwo = "field_two_11", - persisted = Some(Nested(fieldOne = "field_one_nested_11")) - ), - RepeatedNestedRead( - fieldOne = "field_one_21", - fieldTwo = "field_two_21", - persisted = Some(Nested(fieldOne = "field_one_nested_21")) - ), - RepeatedNestedRead( - fieldOne = "field_one_31", - fieldTwo = "field_two_31", - persisted = Some(Nested(fieldOne = "field_one_nested_31")) - ) - ), - additionalField = "additional_1" - ), - Read( - nestedField = Seq( - RepeatedNestedRead( - fieldOne = "field_one_12", - fieldTwo = "field_two_12", - persisted = Some(Nested(fieldOne = "field_one_nested_12")) - ), - RepeatedNestedRead( - fieldOne = "field_one_22", - fieldTwo = "field_two_22", - persisted = Some(Nested(fieldOne = "field_one_nested_22")) - ), - RepeatedNestedRead( - fieldOne = "field_one_32", - fieldTwo = "field_two_32", - persisted = Some(Nested(fieldOne = "field_one_nested_32")) - ) - ), - additionalField = "additional_2" - ) - ) - readDs.collect() must contain theSameElementsAs readExpected - } -} diff --git a/sparksql-scalapb/src/test/scala/SchemaOptionsSpec.scala b/sparksql-scalapb/src/test/scala/SchemaOptionsSpec.scala deleted file mode 100644 index dd53711..0000000 --- a/sparksql-scalapb/src/test/scala/SchemaOptionsSpec.scala +++ /dev/null @@ -1,106 +0,0 @@ -package scalapb.spark - -import com.example.protos.wrappers._ -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.types.IntegerType -import org.apache.spark.sql.types.ArrayType -import org.apache.spark.sql.types.StructField -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.types.StringType -import org.apache.spark.sql.Row - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.matchers.must.Matchers - -class SchemaOptionsSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll { - val spark: SparkSession = SparkSession - .builder() - .appName("ScalaPB Demo") - .master("local[2]") - .getOrCreate() - - import spark.implicits.StringToColumn - - val data = Seq( - PrimitiveWrappers( - intValue = Option(45), - stringValue = Option("boo"), - ints = Seq(17, 19, 25), - strings = Seq("foo", "bar") - ), - PrimitiveWrappers( - intValue = None, - stringValue = None, - ints = Seq(17, 19, 25), - strings = Seq("foo", "bar") - ) - ) - - "converting df with primitive wrappers" should "unpack primitive wrappers by default" in { - import ProtoSQL.implicits._ - val df = ProtoSQL.createDataFrame(spark, data) - df.schema.fields.map(_.dataType).toSeq must be( - Seq( - IntegerType, - StringType, - ArrayType(IntegerType, true), - ArrayType(StringType, true) - ) - ) - df.collect() must contain theSameElementsAs ( - Seq( - Row(45, "boo", 
Seq(17, 19, 25), Seq("foo", "bar")), - Row(null, null, Seq(17, 19, 25), Seq("foo", "bar")) - ) - ) - } - - "converting df with primitive wrappers" should "retain value field when option is set" in { - import ProtoSQL.withRetainedPrimitiveWrappers.implicits._ - val df = ProtoSQL.withRetainedPrimitiveWrappers.createDataFrame(spark, data) - df.schema.fields.map(_.dataType).toSeq must be( - Seq( - StructType(Seq(StructField("value", IntegerType, true))), - StructType(Seq(StructField("value", StringType, true))), - ArrayType( - StructType(Seq(StructField("value", IntegerType, true))), - true - ), - ArrayType( - StructType(Seq(StructField("value", StringType, true))), - true - ) - ) - ) - df.collect() must contain theSameElementsAs ( - Seq( - Row( - Row(45), - Row("boo"), - Seq(Row(17), Row(19), Row(25)), - Seq(Row("foo"), Row("bar")) - ), - Row( - null, - null, - Seq(Row(17), Row(19), Row(25)), - Seq(Row("foo"), Row("bar")) - ) - ) - ) - } - - "schema" should "use scalaNames when option is set" in { - val scalaNameProtoSQL = new ProtoSQL(SchemaOptions.Default.withScalaNames) - import scalaNameProtoSQL.implicits._ - val df = scalaNameProtoSQL.createDataFrame(spark, data) - df.schema.fieldNames.toVector must contain theSameElementsAs (Seq( - "intValue", - "stringValue", - "ints", - "strings" - )) - df.collect() - } -} diff --git a/sparksql-scalapb/src/test/scala/TimestampSpec.scala b/sparksql-scalapb/src/test/scala/TimestampSpec.scala deleted file mode 100644 index cadf9a2..0000000 --- a/sparksql-scalapb/src/test/scala/TimestampSpec.scala +++ /dev/null @@ -1,214 +0,0 @@ -package scalapb.spark - -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.types._ -import org.scalatest.BeforeAndAfterAll -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.matchers.must.Matchers -import scalapb.spark.test3.customizations.{ - BothTimestampTypes, - SQLTimestampFromGoogleTimestamp, - StructFromGoogleTimestamp, - TimestampTypesMap -} - -import java.sql.{Timestamp => SQLTimestamp} -import com.google.protobuf.timestamp.{Timestamp => GoogleTimestamp} - -import java.time.Instant - -case class TestTimestampsHolder( - justLong: Long, - ts: SQLTimestamp, - bothTimestampTypes: BothTimestampTypes -) - -class TimestampSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll { - val spark: SparkSession = SparkSession - .builder() - .appName("ScalaPB Demo") - .master("local[2]") - .config("spark.ui.enabled", "false") - .getOrCreate() - - val data = Seq( - TestTimestampsHolder( - 0, - new SQLTimestamp(1), - BothTimestampTypes( - // 2 seconds + 3 milliseconds + 4 microseconds + 5 nanoseconds - googleTs = Some(GoogleTimestamp(2, 3 * 1000000 + 4 * 1000 + 5)), - googleTsAsSqlTs = Some({ - // 5 seconds - val ts = new SQLTimestamp(5 * 1000) - // + 6 milliseconds + 7 microseconds + 8 nanoseconds - ts.setNanos(6 * 1000000 + 7 * 1000 + 8) - ts - }) - ) - ) - ) - - // 4 seconds + 5 microseconds + 6 nanoseconds - val googleTimestampNanosPrecision = GoogleTimestamp(4, 5 * 1000 + 6) - val googleTimestampMicrosPrecision = GoogleTimestamp( - googleTimestampNanosPrecision.seconds, - googleTimestampNanosPrecision.nanos / 1000 * 1000 - ) - val sqlTimestampMicrosPrecision = SQLTimestamp.from( - Instant.ofEpochSecond( - googleTimestampMicrosPrecision.seconds, - googleTimestampMicrosPrecision.nanos - ) - ) - - val protoMessagesWithGoogleTimestamp = Seq( - StructFromGoogleTimestamp( - googleTs = Some(googleTimestampNanosPrecision) - ), - StructFromGoogleTimestamp( - googleTs = 
Some(googleTimestampNanosPrecision) - ) - ) - - val protoMessagesWithGoogleTimestampMappedToSQLTimestamp = Seq( - SQLTimestampFromGoogleTimestamp(googleTsAsSqlTs = Some(sqlTimestampMicrosPrecision)), - SQLTimestampFromGoogleTimestamp(googleTsAsSqlTs = Some(sqlTimestampMicrosPrecision)) - ) - - "ProtoSQL.createDataFrame from proto messages with google timestamp" should "have a spark schema field type of TimestampType" in { - val df: DataFrame = - ProtoSQL.withSparkTimestamps.createDataFrame(spark, protoMessagesWithGoogleTimestamp) - df.schema.fields.map(_.dataType).toSeq must be( - Seq( - TimestampType - ) - ) - } - - "ProtoSQL.createDataFrame from proto messages with google timestamp" should "be able to collect items with microsecond timestamp precision" in { - val df: DataFrame = - ProtoSQL.withSparkTimestamps.createDataFrame(spark, protoMessagesWithGoogleTimestamp) - - df.collect().map(_.toSeq) must contain theSameElementsAs Seq( - Seq(sqlTimestampMicrosPrecision), - Seq(sqlTimestampMicrosPrecision) - ) - } - - "spark.createDataset from proto messages with google timestamp" should "have a spark schema field type of TimestampType" in { - import ProtoSQL.withSparkTimestamps.implicits._ - - val ds: Dataset[StructFromGoogleTimestamp] = - spark.createDataset(protoMessagesWithGoogleTimestamp) - ds.schema.fields.map(_.dataType).toSeq must be( - Seq( - TimestampType - ) - ) - } - - "spark.createDataset from proto messages with google timestamp" should "be able to collect items with correct timestamp values" in { - import ProtoSQL.withSparkTimestamps.implicits._ - - val ds: Dataset[StructFromGoogleTimestamp] = - spark.createDataset(protoMessagesWithGoogleTimestamp) - ds.collect() must contain theSameElementsAs Seq( - StructFromGoogleTimestamp(googleTs = Some(googleTimestampMicrosPrecision)), - StructFromGoogleTimestamp(googleTs = Some(googleTimestampMicrosPrecision)) - ) - } - - "spark.createDataset from proto messages with google timestamp" should "be able to convert items with correct timestamp values" in { - import ProtoSQL.withSparkTimestamps.implicits._ - - val ds: Dataset[StructFromGoogleTimestamp] = - spark.createDataset(protoMessagesWithGoogleTimestamp) - - val dsMapped: Dataset[StructFromGoogleTimestamp] = ds.map(record => record) - - dsMapped.collect() must contain theSameElementsAs Seq( - StructFromGoogleTimestamp(googleTs = Some(googleTimestampMicrosPrecision)), - StructFromGoogleTimestamp(googleTs = Some(googleTimestampMicrosPrecision)) - ) - } - - "spark.createDataset from proto messages with spark timestamp" should "have a spark schema field type of TimestampType" in { - import ProtoSQL.withSparkTimestamps.implicits._ - - val ds: Dataset[SQLTimestampFromGoogleTimestamp] = - spark.createDataset(protoMessagesWithGoogleTimestampMappedToSQLTimestamp) - ds.schema.fields.map(_.dataType).toSeq must be( - Seq( - TimestampType - ) - ) - } - - "spark.createDataset from proto messages with spark timestamp" should "be able to convert items with correct timestamp values" in { - import ProtoSQL.withSparkTimestamps.implicits._ - - val ds: Dataset[SQLTimestampFromGoogleTimestamp] = - spark.createDataset(protoMessagesWithGoogleTimestampMappedToSQLTimestamp) - - val dsMapped = ds.map(record => record) - - dsMapped.collect() must contain theSameElementsAs Seq( - SQLTimestampFromGoogleTimestamp(googleTsAsSqlTs = Some(sqlTimestampMicrosPrecision)), - SQLTimestampFromGoogleTimestamp(googleTsAsSqlTs = Some(sqlTimestampMicrosPrecision)) - ) - } - - "spark.createDataset from proto messages with spark 
timestamp in map" should "be able to convert items with correct timestamp values" in { - import ProtoSQL.withSparkTimestamps.implicits._ - - val value = TimestampTypesMap(mapField = - Map( - "a" -> SQLTimestampFromGoogleTimestamp(googleTsAsSqlTs = Some(sqlTimestampMicrosPrecision)) - ) - ) - val ds: Dataset[TimestampTypesMap] = spark.createDataset(Seq(value)) - - ds.collect() must contain theSameElementsAs Seq( - value - ) - } - - "df with case class timestamp as well as both types of google timestamp" should "not have StructType for timestamps" in { - import ProtoSQL.withSparkTimestamps.implicits._ - - implicit val timestampInjection = - new frameless.Injection[SQLTimestamp, frameless.SQLTimestamp] { - def apply(ts: SQLTimestamp): frameless.SQLTimestamp = { - val i = ts.toInstant() - frameless.SQLTimestamp(i.getEpochSecond() * 1000000 + i.getNano() / 1000) - } - - def invert(l: frameless.SQLTimestamp): SQLTimestamp = SQLTimestamp.from( - java.time.Instant.EPOCH.plus(l.us, java.time.temporal.ChronoUnit.MICROS) - ) - } - - val ds: Dataset[TestTimestampsHolder] = spark.createDataset(data) - - ds.schema must be( - StructType( - Seq( - StructField("justLong", LongType), - StructField("ts", TimestampType), - StructField( - "bothTimestampTypes", - StructType( - Seq( - StructField("google_ts", TimestampType), - StructField("google_ts_as_sql_ts", TimestampType) - ) - ) - ) - ) - ) - ) - - } - -}