
Commit 7c70d30

Switch spark connector to use Java client6

1 parent 8502763 · commit 7c70d30

18 files changed: +600, -715 lines

.github/workflows/create-release.yml
Lines changed: 2 additions & 3 deletions

@@ -14,7 +14,7 @@ jobs:
         uses: actions/setup-java@v3
         with:
           distribution: temurin
-          java-version: 11
+          java-version: 17
       - name: Run integration tests to ensure release works
         run: |
           sbt -v +test
@@ -25,7 +25,6 @@ jobs:
       - name: Create jar with all dependencies included
         run: |
           sbt +assembly
-          test -f ./target/scala-2.12/spark-connector-assembly-${{ env.version }}.jar
           test -f ./target/scala-2.13/spark-connector-assembly-${{ env.version }}.jar
       - name: Publish artifact to sonatype and release to maven repo
         env:
@@ -37,5 +36,5 @@ jobs:
         with:
           generateReleasenotes: true
           name: Release ${{ env.version }}
-          artifacts: "./target/scala-2.12/spark-connector-assembly-${{ env.version }}.jar,./target/scala-2.13/spark-connector-assembly-${{ env.version }}.jar"
+          artifacts: "./target/scala-2.13/spark-connector-assembly-${{ env.version }}.jar"
           artifactContentType: application/java-archive

.github/workflows/on-pull-request.yml
Lines changed: 2 additions & 3 deletions

@@ -11,7 +11,7 @@ jobs:
     strategy:
       matrix:
         distribution: [ 'zulu', 'temurin' ]
-        java: [ '8', '11' ]
+        java: [ '17', '21' ]
         os: [ 'ubuntu-latest' ]
     steps:
       - uses: actions/checkout@v4
@@ -34,7 +34,7 @@ jobs:
         uses: actions/setup-java@v4
         with:
           distribution: temurin
-          java-version: 11
+          java-version: 17
       - uses: sbt/setup-sbt@v1
       - name: Get release version
         run: |
@@ -43,7 +43,6 @@ jobs:
       - name: Create jar with all dependencies included
         run: |
           sbt +assembly
-          test -f ./target/scala-2.12/spark-connector-assembly-${{ env.version }}.jar
           test -f ./target/scala-2.13/spark-connector-assembly-${{ env.version }}.jar
      - uses: actions/setup-python@v4
        with:

.sbtopts
Lines changed: 2 additions & 1 deletion

@@ -10,4 +10,5 @@
 -J--add-opens=java.base/sun.nio.ch=ALL-UNNAMED
 -J--add-opens=java.base/sun.nio.cs=ALL-UNNAMED
 -J--add-opens=java.base/sun.security.action=ALL-UNNAMED
--J--add-opens=java.base/sun.util.calendar=ALL-UNNAMED
+-J--add-opens=java.base/sun.util.calendar=ALL-UNNAMED
+-J--add-exports=java.base/sun.nio.ch=ALL-UNNAMED
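These flags exist because JDK 17 strongly encapsulates its internal packages, while Spark still reflects into sun.nio.ch, sun.util.calendar, and friends. Note that .sbtopts only affects the sbt JVM itself; a minimal sketch of mirroring the same openings for forked test JVMs from build.sbt (standard sbt settings, not part of this commit):

    // Sketch: pass the .sbtopts openings to forked test JVMs as well.
    // Assumption: tests run forked; these setting names are stock sbt, not from this diff.
    Test / fork := true
    Test / javaOptions ++= Seq(
      "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED",        // Spark's NIO internals
      "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED"  // Spark's date/time internals
    )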

build.sbt
Lines changed: 5 additions & 5 deletions

@@ -2,9 +2,9 @@ import ReleaseTransformations._
 
 resolvers += Resolver.mavenLocal
 
-ThisBuild / scalaVersion := "2.12.18"
+ThisBuild / scalaVersion := "2.13.17"
 
-crossScalaVersions := Seq("2.12.18", "2.13.12")
+javaOptions ++= Seq("-target", "17", "-source", "17")
 
 lazy val root = (project in file("."))
   .settings(
@@ -14,9 +14,9 @@ lazy val root = (project in file("."))
 
 ThisBuild / scalafixDependencies += "org.scalalint" %% "rules" % "0.2.1" % "runtime"
 
-lazy val sparkVersion = "3.5.5"
+lazy val sparkVersion = "4.0.1"
 lazy val grpcNettyShadedVersion = "1.72.0"
-lazy val weaviateClientVersion = "5.2.1"
+lazy val weaviateClient6Version = "6.0.0-M2"
 lazy val scalaCollectionCompatVersion = "2.13.0"
 lazy val scalatestVersion = "3.2.19"
 lazy val gsonVersion = "2.13.1"
@@ -26,7 +26,7 @@ libraryDependencies ++= Seq(
   "org.apache.spark" %% "spark-sql" % sparkVersion % "provided,test",
   "org.apache.spark" %% "spark-catalyst" % sparkVersion % "provided,test",
   "org.scala-lang.modules" %% "scala-collection-compat" % scalaCollectionCompatVersion,
-  "io.weaviate" % "client" % weaviateClientVersion,
+  "io.weaviate" % "client6" % weaviateClient6Version,
   "io.grpc" % "grpc-netty-shaded" % grpcNettyShadedVersion,
   "com.google.code.gson" % "gson" % gsonVersion
 )
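Context for these bumps: Spark 4.x supports only Scala 2.13 and requires Java 17+, which is why crossScalaVersions disappears here and the workflows above no longer build or check a Scala 2.12 assembly jar. For a downstream build, the client swap looks like this (a sketch; only the coordinates come from this diff):

    // Sketch of the dependency change for a downstream build.sbt:
    libraryDependencies ++= Seq(
      "org.apache.spark" %% "spark-sql" % "4.0.1" % "provided",
      "io.weaviate" % "client6" % "6.0.0-M2"  // Java client 6; replaces "io.weaviate" % "client" % "5.2.1"
    )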

project/build.properties
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-sbt.version = 1.7.3
+sbt.version = 1.11.7

project/plugins.sbt
Lines changed: 6 additions & 6 deletions

@@ -1,6 +1,6 @@
-addSbtPlugin("org.jetbrains.scala" % "sbt-ide-settings" % "1.1.1")
-addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.9.8")
-addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "1.2.0")
-addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.2.1")
-addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.9.15")
-addSbtPlugin("com.github.sbt" % "sbt-release" % "1.1.0")
+addSbtPlugin("org.jetbrains.scala" % "sbt-ide-settings" % "1.1.3")
+addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.14.3")
+addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.3.1")
+addSbtPlugin("com.github.sbt" % "sbt-pgp" % "2.3.1")
+addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.11.0")
+addSbtPlugin("com.github.sbt" % "sbt-release" % "1.4.0")

src/main/scala/Utils.scala
Lines changed: 5 additions & 5 deletions

@@ -1,13 +1,13 @@
 package io.weaviate.spark
 
-import io.weaviate.client.v1.schema.model.Property
+import io.weaviate.client6.v1.api.collections.Property
 import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
 
-import scala.collection.JavaConverters._
+import scala.jdk.CollectionConverters._
 
 import java.util
 
 object Utils {
-  def weaviateToSparkDatatype(datatype: util.List[String], nestedProperties: util.List[Property.NestedProperty]): DataType = {
+  def weaviateToSparkDatatype(datatype: util.List[String], nestedProperties: util.List[Property]): DataType = {
     datatype.get(0) match {
       case "string" => DataTypes.StringType
       case "string[]" => DataTypes.createArrayType(DataTypes.StringType)
@@ -27,9 +27,9 @@ object Utils {
     }
   }
 
-  private def createStructType(nestedProperties: util.List[Property.NestedProperty]): StructType = {
+  private def createStructType(nestedProperties: util.List[Property]): StructType = {
     val fields = nestedProperties.asScala.map(prop => {
-      StructField(name = prop.getName, dataType = weaviateToSparkDatatype(prop.getDataType, prop.getNestedProperties))
+      StructField(name = prop.propertyName(), dataType = weaviateToSparkDatatype(prop.dataTypes(), prop.nestedProperties()))
     }).asJava
 
     DataTypes.createStructType(fields)
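The mapping logic is unchanged; only the accessors move from the v5 getters (getName, getDataType, getNestedProperties) to client6's propertyName(), dataTypes(), and nestedProperties(). A quick sketch of the primitive case shown above (the empty list stands in for a type with no nested properties):

    // Sketch: "string" maps to StringType per the match above.
    import org.apache.spark.sql.types.DataTypes

    val dt = Utils.weaviateToSparkDatatype(
      java.util.List.of("string"),           // Weaviate dataType list
      java.util.Collections.emptyList())     // primitives carry no nested properties
    assert(dt == DataTypes.StringType)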

src/main/scala/Weaviate.scala
Lines changed: 4 additions & 7 deletions

@@ -9,20 +9,17 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import java.util
 import scala.jdk.CollectionConverters._
 
-
-
 class Weaviate extends TableProvider with DataSourceRegister {
   override def shortName(): String = "weaviate"
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
     val weaviateOptions = new WeaviateOptions(options)
     val client = weaviateOptions.getClient()
     val className = weaviateOptions.className
-    val result = client.schema.classGetter.withClassName(className).run
-    if (result.hasErrors) throw new WeaviateResultError(result.getError.getMessages.toString)
-    if (result.getResult == null) throw new WeaviateClassNotFoundError("Class "+className+ " was not found.")
-    val properties = result.getResult.getProperties.asScala
+    val result = client.collections.getConfig(className)
+    if (result.isEmpty) throw WeaviateClassNotFoundError("Class "+className+ " was not found.")
+    val properties = result.get().properties().asScala
     val structFields = properties.map(p =>
-      StructField(p.getName(), Utils.weaviateToSparkDatatype(p.getDataType, p.getNestedProperties), true, Metadata.empty))
+      StructField(p.propertyName(), Utils.weaviateToSparkDatatype(p.dataTypes(), p.nestedProperties()), true, Metadata.empty))
     if (weaviateOptions.vector != null)
       structFields.append(StructField(weaviateOptions.vector, DataTypes.createArrayType(DataTypes.FloatType), true, Metadata.empty))
     if (weaviateOptions.id != null)
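Schema inference now goes through client6's collections API: getConfig returns an empty result when the collection does not exist, so a single isEmpty check replaces the v5 hasErrors/getResult pair. From the Spark side nothing changes; a usage sketch (assuming a SparkSession named spark in scope; the "className" option key is an assumption, not shown in this diff):

    // Sketch: reading through the connector triggers inferSchema above.
    val df = spark.read
      .format("weaviate")              // shortName() registered by this class
      .option("className", "Article")  // hypothetical option key
      .load()
    df.printSchema()                   // StructFields built from the collection's properties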

src/main/scala/WeaviateDataWriter.scala
Lines changed: 47 additions & 36 deletions

@@ -2,12 +2,12 @@ package io.weaviate.spark
 
 import com.google.gson.reflect.TypeToken
 import com.google.gson.{Gson, JsonSyntaxException}
+import io.weaviate.client6.v1.api.collections.data.Reference
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
 import org.apache.spark.sql.types._
-import io.weaviate.client.v1.data.model.WeaviateObject
-import io.weaviate.client.v1.schema.model.WeaviateClass
+import io.weaviate.client6.v1.api.collections.{CollectionConfig, ObjectMetadata, Vectors, WeaviateObject}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 
 import java.util.{Map => JavaMap}
@@ -20,12 +20,12 @@ case class WeaviateCommitMessage(msg: String) extends WriterCommitMessage
 
 case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructType)
   extends DataWriter[InternalRow] with Serializable with Logging {
-  var batch = mutable.Map[String, WeaviateObject]()
-  private val weaviateClass = weaviateOptions.getWeaviateClass()
+  var batch = mutable.Map[String, WeaviateObject[JavaMap[String, Object], Reference, ObjectMetadata]]()
+  private lazy val weaviateClass = weaviateOptions.getCollectionConfig()
 
   override def write(record: InternalRow): Unit = {
     val weaviateObject = buildWeaviateObject(record, weaviateClass)
-    batch += (weaviateObject.getId -> weaviateObject)
+    batch += (weaviateObject.uuid() -> weaviateObject)
 
     if (batch.size >= weaviateOptions.batchSize) writeBatch()
   }
@@ -36,38 +36,35 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
     val consistencyLevel = weaviateOptions.consistencyLevel
     val client = weaviateOptions.getClient()
 
-    val results = if (consistencyLevel != "") {
-      logInfo(s"Writing using consistency level: ${consistencyLevel}")
-      client.batch().objectsBatcher().withObjects(batch.values.toList: _*).withConsistencyLevel(consistencyLevel).run()
-    } else {
-      client.batch().objectsBatcher().withObjects(batch.values.toList: _*).run()
-    }
+    val collection = client.collections.use(weaviateOptions.className).withTenant(weaviateOptions.tenant).withConsistencyLevel(consistencyLevel)
+
+    val results = collection.data.insertMany(batch.values.toList.asJava)
 
     val IDs = batch.keys.toList
 
-    if (results.hasErrors || results.getResult == null) {
+    if (results.errors() != null && !results.errors().isEmpty) {
       if (retries == 0) {
         throw WeaviateResultError(s"error getting result and no more retries left." +
-          s" Error from Weaviate: ${results.getError.getMessages}")
+          s" Error from Weaviate: ${results.errors().asScala.mkString(",")}")
       }
       if (retries > 0) {
-        logError(s"batch error: ${results.getError.getMessages}, will retry")
+        logError(s"batch error: ${results.errors().asScala.mkString(",")}, will retry")
         logInfo(s"Retrying batch in ${weaviateOptions.retriesBackoff} seconds. Batch has following IDs: ${IDs}")
         Thread.sleep(weaviateOptions.retriesBackoff * 1000)
         writeBatch(retries - 1)
       }
     } else {
-      val (objectsWithSuccess, objectsWithError) = results.getResult.partition(_.getResult.getErrors == null)
+      val (objectsWithSuccess, objectsWithError) = results.responses().asScala.partition(_.error() == null)
       if (objectsWithError.size > 0 && retries > 0) {
-        val errors = objectsWithError.map(obj => s"${obj.getId}: ${obj.getResult.getErrors.toString}")
-        val successIDs = objectsWithSuccess.map(_.getId).toList
+        val errors = objectsWithError.map(obj => s"${obj.uuid()}: ${obj.error()}")
+        val successIDs = objectsWithSuccess.map(_.uuid()).toList
         logWarning(s"Successfully imported ${successIDs}. " +
           s"Retrying objects with an error. Following objects in the batch upload had an error: ${errors.mkString("Array(", ", ", ")")}")
         batch = batch -- successIDs
        writeBatch(retries - 1)
       } else if (objectsWithError.size > 0) {
-        val errorIds = objectsWithError.map(obj => obj.getId)
-        val errorMessages = objectsWithError.map(obj => obj.getResult.getErrors.toString).distinct
+        val errorIds = objectsWithError.map(obj => obj.uuid())
+        val errorMessages = objectsWithError.map(obj => obj.error()).distinct
         throw WeaviateResultError(s"Error writing to weaviate and no more retries left." +
           s" IDs with errors: ${errorIds.mkString("Array(", ", ", ")")}." +
          s" Error messages: ${errorMessages.mkString("Array(", ", ", ")")}")
@@ -79,17 +76,19 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
     }
   }
 
-  private[spark] def buildWeaviateObject(record: InternalRow, weaviateClass: WeaviateClass = null): WeaviateObject = {
-    var builder = WeaviateObject.builder.className(weaviateOptions.className)
-    if (weaviateOptions.tenant != null) {
-      builder = builder.tenant(weaviateOptions.tenant)
-    }
+  private[spark] def buildWeaviateObject(record: InternalRow, collectionConfig: CollectionConfig = null): WeaviateObject[java.util.Map[String, Object], Reference, ObjectMetadata] = {
+    var builder: WeaviateObject.Builder[java.util.Map[String, Object], Reference, ObjectMetadata] = new WeaviateObject.Builder()
+
+    builder = builder.collection(weaviateOptions.className)
+
+    var id: String = null
     val properties = mutable.Map[String, AnyRef]()
+    var vector: Array[Float] = null
     val vectors = mutable.Map[String, Array[Float]]()
     val multiVectors = mutable.Map[String, Array[Array[Float]]]()
     schema.zipWithIndex.foreach(field =>
       field._1.name match {
-        case weaviateOptions.vector => builder = builder.vector(record.getArray(field._2).toArray(FloatType))
+        case weaviateOptions.vector => vector = record.getArray(field._2).toArray(FloatType)
         case key if weaviateOptions.vectors.contains(key) => vectors += (weaviateOptions.vectors(key) -> record.getArray(field._2).toArray(FloatType))
         case key if weaviateOptions.multiVectors.contains(key) => {
           val multiVectorArrayData = record.get(field._2, ArrayType(ArrayType(FloatType))) match {
@@ -105,31 +104,43 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
 
           multiVectors += (weaviateOptions.multiVectors(key) -> multiVector)
         }
-        case weaviateOptions.id => builder = builder.id(record.getString(field._2))
-        case _ => properties(field._1.name) = getPropertyValue(field._2, record, field._1.dataType, false, field._1.name, weaviateClass)
+        case weaviateOptions.id => id = record.getString(field._2)
+        case _ => properties(field._1.name) = getPropertyValue(field._2, record, field._1.dataType, false, field._1.name, collectionConfig)
       }
     )
+    val metadata = new ObjectMetadata.Builder()
+
     if (weaviateOptions.id == null) {
-      builder.id(java.util.UUID.randomUUID.toString)
+      metadata.uuid(java.util.UUID.randomUUID.toString)
+    } else {
+      metadata.uuid(id)
     }
 
+    val allvectors = ListBuffer.empty[Vectors]
+    if (vector != null) {
+      allvectors += Vectors.of(vector)
+    }
     if (vectors.nonEmpty) {
-      builder.vectors(vectors.map { case (key, arr) => key -> arr.map(Float.box) }.asJava)
+      allvectors ++= vectors.map { case (key, values) => Vectors.of(key, values) }
     }
     if (multiVectors.nonEmpty) {
-      builder.multiVectors(multiVectors.map { case (key, multiVector) => key -> multiVector.map { vec => { vec.map(Float.box) }} }.toMap.asJava)
+      allvectors ++= multiVectors.map { case (key, multiVector) => Vectors.of(key, multiVector) }
     }
-    builder.properties(properties.asJava).build
+    metadata.vectors(allvectors.toSeq: _*)
+
+    builder.properties(properties.asJava).metadata(metadata.build()).build()
  }
 
-  def getPropertyValue(index: Int, record: InternalRow, dataType: DataType, parseObjectArrayItem: Boolean, propertyName: String, weaviateClass: WeaviateClass): AnyRef = {
+  def getPropertyValue(index: Int, record: InternalRow, dataType: DataType, parseObjectArrayItem: Boolean, propertyName: String, collectionConfig: CollectionConfig): AnyRef = {
     val valueFromField = getValueFromField(index, record, dataType, parseObjectArrayItem)
-    if (weaviateClass != null) {
+    if (collectionConfig != null) {
       var dt = ""
-      weaviateClass.getProperties.forEach(p => {
-        if (p.getName == propertyName) {
+      collectionConfig.properties().forEach(p => {
+        if (p.propertyName() == propertyName) {
          // we are just looking for geoCoordinates or phoneNumber type
-          dt = p.getDataType.get(0)
+          dt = p.dataTypes().get(0)
        }
      })
      if ((dt == "geoCoordinates" || dt == "phoneNumber") && valueFromField.isInstanceOf[String]) {