Skip to content

Commit 76e405b

Browse files
committed
switch to Client6 RC2 version
1 parent e275d56 commit 76e405b

File tree

7 files changed

+39
-40
lines changed

7 files changed

+39
-40
lines changed

build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ ThisBuild / scalafixDependencies += "org.scalalint" %% "rules" % "0.2.1" % "runt
1616

1717
lazy val sparkVersion = "4.0.1"
1818
lazy val grpcNettyShadedVersion = "1.76.0"
19-
lazy val weaviateClient6Version = "6.0.0-M2"
19+
lazy val weaviateClient6Version = "6.0.0-RC2"
2020
lazy val scalaCollectionCompatVersion = "2.14.0"
2121
lazy val scalatestVersion = "3.2.19"
2222
lazy val gsonVersion = "2.13.2"

src/main/scala/WeaviateDataWriter.scala

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ package io.weaviate.spark
22

33
import com.google.gson.reflect.TypeToken
44
import com.google.gson.{Gson, JsonSyntaxException}
5-
import io.weaviate.client6.v1.api.collections.data.Reference
5+
import io.weaviate.client6.v1.api.collections.WeaviateObject
66
import org.apache.spark.internal.Logging
77
import org.apache.spark.sql.catalyst.InternalRow
88
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
99
import org.apache.spark.sql.types._
10-
import io.weaviate.client6.v1.api.collections.{CollectionConfig, ObjectMetadata, Vectors, WeaviateObject}
10+
import io.weaviate.client6.v1.api.collections.{CollectionConfig, Vectors}
1111
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
1212

1313
import java.util.{Map => JavaMap}
@@ -20,23 +20,25 @@ case class WeaviateCommitMessage(msg: String) extends WriterCommitMessage
2020

2121
case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructType)
2222
extends DataWriter[InternalRow] with Serializable with Logging {
23-
var batch = mutable.Map[String, WeaviateObject[JavaMap[String, Object], Reference, ObjectMetadata]]()
23+
var batch = mutable.Map[String, WeaviateObject[JavaMap[String, Object]]]()
2424
private lazy val weaviateClass = weaviateOptions.getCollectionConfig()
2525

2626
override def write(record: InternalRow): Unit = {
2727
val weaviateObject = buildWeaviateObject(record, weaviateClass)
28-
batch += (weaviateObject.uuid() -> weaviateObject)
28+
batch += (weaviateObject.uuid -> weaviateObject)
2929

3030
if (batch.size >= weaviateOptions.batchSize) writeBatch()
3131
}
3232

3333
def writeBatch(retries: Int = weaviateOptions.retries): Unit = {
3434
if (batch.isEmpty) return
3535

36-
val consistencyLevel = weaviateOptions.consistencyLevel
3736
val client = weaviateOptions.getClient()
3837

39-
val collection = client.collections.use(weaviateOptions.className).withTenant(weaviateOptions.tenant).withConsistencyLevel(consistencyLevel)
38+
val collection = client.collections
39+
.use(weaviateOptions.className)
40+
.withTenant(weaviateOptions.tenant)
41+
.withConsistencyLevel(weaviateOptions.consistencyLevel)
4042

4143
val results = collection.data.insertMany(batch.values.toList.asJava)
4244

@@ -76,11 +78,9 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
7678
}
7779
}
7880

79-
private[spark] def buildWeaviateObject(record: InternalRow, collectionConfig: CollectionConfig = null): WeaviateObject[java.util.Map[String, Object], Reference, ObjectMetadata] = {
80-
var builder: WeaviateObject.Builder[java.util.Map[String, Object], Reference, ObjectMetadata] = new WeaviateObject.Builder()
81-
builder = builder.collection(weaviateOptions.className)
81+
private[spark] def buildWeaviateObject(record: InternalRow, collectionConfig: CollectionConfig = null): WeaviateObject[java.util.Map[String, Object]] = {
82+
val builder: WeaviateObject.Builder[java.util.Map[String, Object]] = new WeaviateObject.Builder()
8283

83-
val metadata = new ObjectMetadata.Builder()
8484
val properties = mutable.Map[String, AnyRef]()
8585
var vector: Array[Float] = null
8686
val vectors = mutable.Map[String, Array[Float]]()
@@ -103,13 +103,13 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
103103

104104
multiVectors += (weaviateOptions.multiVectors(key) -> multiVector)
105105
}
106-
case weaviateOptions.id => metadata.uuid(record.getString(field._2))
106+
case weaviateOptions.id => builder.uuid(record.getString(field._2))
107107
case _ => properties(field._1.name) = getPropertyValue(field._2, record, field._1.dataType, false, field._1.name, collectionConfig)
108108
}
109109
)
110110

111111
if (weaviateOptions.id == null) {
112-
metadata.uuid(java.util.UUID.randomUUID.toString)
112+
builder.uuid(java.util.UUID.randomUUID.toString)
113113
}
114114

115115
val allVectors = ListBuffer.empty[Vectors]
@@ -124,9 +124,8 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
124124
val arr = multiVectors.map { case (key, multiVector) => Vectors.of(key, multiVector) }.toArray
125125
allVectors ++= arr
126126
}
127-
metadata.vectors(allVectors.toSeq : _*)
128127

129-
builder.properties(properties.asJava).metadata(metadata.build()).build()
128+
builder.tenant(weaviateOptions.tenant).properties(properties.asJava).vectors(allVectors.toSeq : _*).build()
130129
}
131130

132131
def getPropertyValue(index: Int, record: InternalRow, dataType: DataType, parseObjectArrayItem: Boolean, propertyName: String, collectionConfig: CollectionConfig): AnyRef = {
@@ -216,7 +215,7 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
216215
})
217216
}
218217
objList.asJava
219-
case default => throw new SparkDataTypeNotSupported(s"DataType ${default} is not supported by Weaviate")
218+
case default => throw SparkDataTypeNotSupported(s"DataType ${default} is not supported by Weaviate")
220219
}
221220
}
222221

src/main/scala/WeaviateOptions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class WeaviateOptions(config: CaseInsensitiveStringMap) extends Serializable {
9090
if (oidcUsername.trim().nonEmpty && oidcPassword.trim().nonEmpty) {
9191
config.authentication(Authentication.resourceOwnerPassword(oidcUsername, oidcPassword, oidcScopes.asJava))
9292
} else if (oidcClientSecret.trim().nonEmpty) {
93-
config.authentication(Authentication.clientCredentials(oidcClientId, oidcClientSecret, oidcScopes.asJava))
93+
config.authentication(Authentication.clientCredentials(oidcClientSecret, oidcScopes.asJava))
9494
} else if (oidcAccessToken.trim().nonEmpty) {
9595
config.authentication(Authentication.bearerToken(oidcAccessToken, oidcRefreshToken, oidcAccessTokenLifetime))
9696
} else if (apiKey.trim().nonEmpty) {

src/test/python/test_spark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def get_connector_version():
3333

3434
connector_version = os.environ.get("CONNECTOR_VERSION", get_connector_version())
3535
scala_version = os.environ.get("SCALA_VERSION", "2.13")
36-
weaviate_version = os.environ.get("WEAVIATE_VERSION", "1.30.3")
36+
weaviate_version = os.environ.get("WEAVIATE_VERSION", "1.32.17")
3737
spark_connector_jar_path = os.environ.get(
3838
"CONNECTOR_JAR_PATH", f"target/scala-{scala_version}/spark-connector-assembly-{connector_version}.jar"
3939
)
@@ -69,7 +69,7 @@ def weaviate_client():
6969
"CLUSTER_HOSTNAME": "node1",
7070
"PERSISTENCE_DATA_PATH": "./data"},
7171
)
72-
time.sleep(0.5)
72+
time.sleep(2)
7373
wclient = weaviate.Client('http://localhost:8080')
7474
test_class_name = "TestWillBeRemoved"
7575
retries = 3

src/test/scala/SparkIntegrationTest.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import org.apache.spark.sql.execution.streaming.MemoryStream
55
import org.apache.spark.sql.{AnalysisException, DataFrame, Encoder, Encoders}
66
import org.scalatest.BeforeAndAfter
77
import org.scalatest.funsuite.AnyFunSuite
8-
import io.weaviate.client6.v1.api.collections.{Property}
8+
import io.weaviate.client6.v1.api.collections.Property
99
import io.weaviate.client6.v1.internal.ObjectBuilder
1010
import org.apache.spark.SparkException
1111

src/test/scala/TestWeaviateDataWriter.scala

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ class TestWeaviateDataWriter extends AnyFunSuite {
2929
assert(weaviateObject.properties().get("title").equals("Sam"))
3030
assert(weaviateObject.properties().get("content") == "Sam")
3131
assert(weaviateObject.properties().get("wordCount") == 5)
32-
// how to get tenant?
33-
// assert(weaviateObject. == "TenantA")
32+
assert(weaviateObject.tenant() == "TenantA")
3433
}
3534

3635
test("Test Build Weaviate Object without supplied ID") {
@@ -52,7 +51,7 @@ class TestWeaviateDataWriter extends AnyFunSuite {
5251
assert(weaviateObject.properties().get("content") == "Sam")
5352
assert(weaviateObject.properties().get("wordCount") == 5)
5453
assert(weaviateObject.uuid() != null)
55-
// assert(weaviateObject.getTenant == null)
54+
assert(weaviateObject.tenant() == null)
5655
}
5756

5857
test("Test Build Weaviate Object with supplied ID") {
@@ -78,7 +77,7 @@ class TestWeaviateDataWriter extends AnyFunSuite {
7877
assert(weaviateObject.properties().get("content") == "Sam")
7978
assert(weaviateObject.properties().get("wordCount") == 5)
8079
assert(weaviateObject.uuid() == uuid)
81-
// assert(weaviateObject.getTenant == null)
80+
assert(weaviateObject.tenant() == null)
8281
}
8382

8483
test("Test Build Weaviate Object with DateString") {
@@ -103,7 +102,7 @@ class TestWeaviateDataWriter extends AnyFunSuite {
103102
assert(weaviateObject.properties().get("content") == "Sam")
104103
assert(weaviateObject.properties().get("wordCount") == 5)
105104
assert(weaviateObject.properties().get("date") == "2022-11-18T00:00:00Z")
106-
// assert(weaviateObject.getTenant == null)
105+
assert(weaviateObject.tenant() == null)
107106
}
108107

109108
test("Test Build Weaviate Object with Unsupported Data types") {
@@ -180,9 +179,9 @@ class TestWeaviateDataWriter extends AnyFunSuite {
180179
assert(weaviateObject.properties().get("content") == "Sam")
181180
assert(weaviateObject.properties().get("wordCount") == 5)
182181
assert(weaviateObject.uuid() == uuid)
183-
assert(weaviateObject.metadata().vectors() != null)
184-
assert(weaviateObject.metadata().vectors().getDefaultSingle().sameElements(embedding))
185-
// assert(weaviateObject.getTenant == null)
182+
assert(weaviateObject.vectors() != null)
183+
assert(weaviateObject.vectors().getDefaultSingle().sameElements(embedding))
184+
assert(weaviateObject.tenant() == null)
186185
}
187186

188187
test("Test Build Weaviate Object with vectors") {
@@ -214,10 +213,10 @@ class TestWeaviateDataWriter extends AnyFunSuite {
214213
assert(weaviateObject.properties().get("content") == "Sam")
215214
assert(weaviateObject.properties().get("wordCount") == 5)
216215
assert(weaviateObject.uuid() == uuid)
217-
assert(weaviateObject.metadata().vectors() != null)
218-
assert(weaviateObject.metadata().vectors().getSingle("v1").sameElements(embedding1))
219-
assert(weaviateObject.metadata().vectors().getSingle("v2").sameElements(embedding2))
220-
// assert(weaviateObject.getTenant == null)
216+
assert(weaviateObject.vectors() != null)
217+
assert(weaviateObject.vectors().getSingle("v1").sameElements(embedding1))
218+
assert(weaviateObject.vectors().getSingle("v2").sameElements(embedding2))
219+
assert(weaviateObject.tenant() == null)
221220
}
222221

223222
test("Test Build Weaviate Object with vectors and multi vectors") {
@@ -250,12 +249,12 @@ class TestWeaviateDataWriter extends AnyFunSuite {
250249
assert(weaviateObject.properties().get("content") == "Sam")
251250
assert(weaviateObject.properties().get("wordCount") == 5)
252251
assert(weaviateObject.uuid() == uuid)
253-
assert(weaviateObject.metadata().vectors() != null)
254-
assert(weaviateObject.metadata().vectors().getSingle("v1").sameElements(embedding1))
255-
assert(weaviateObject.metadata().vectors().getMulti("colbert").length == 2)
256-
assert(weaviateObject.metadata().vectors().getMulti("colbert")(0).sameElements(colbert(0)))
257-
assert(weaviateObject.metadata().vectors().getMulti("colbert")(1).sameElements(colbert(1)))
258-
// assert(weaviateObject.getTenant == null)
252+
assert(weaviateObject.vectors() != null)
253+
assert(weaviateObject.vectors().getSingle("v1").sameElements(embedding1))
254+
assert(weaviateObject.vectors().getMulti("colbert").length == 2)
255+
assert(weaviateObject.vectors().getMulti("colbert")(0).sameElements(colbert(0)))
256+
assert(weaviateObject.vectors().getMulti("colbert")(1).sameElements(colbert(1)))
257+
assert(weaviateObject.tenant() == null)
259258
}
260259

261260
test("Test Build Weaviate Object with geo coordinates") {
@@ -287,6 +286,6 @@ class TestWeaviateDataWriter extends AnyFunSuite {
287286
assert(weaviateObject.uuid() == uuid)
288287
assert(weaviateObject.properties().get("title") == "title")
289288
assert(weaviateObject.properties().get("geo") != null)
290-
// assert(weaviateObject.getTenant == null)
289+
assert(weaviateObject.tenant() == null)
291290
}
292291
}

src/test/scala/WeaviateDocker.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ object WeaviateDocker {
2525
var retries = 10
2626

2727
def start(vectorizerModule: String = "none", enableModules: String = "text2vec-openai"): Int = {
28-
val weaviateVersion = "1.30.6"
28+
val weaviateVersion = "1.34.0"
2929
val docker_run =
3030
s"""docker run -d --name=weaviate-test-container-will-be-deleted
3131
-p 8080:8080
@@ -64,6 +64,7 @@ semitechnologies/weaviate:$weaviateVersion"""
6464
}.getOrElse(false)
6565
}
6666

67+
Thread.sleep(2000L)
6768
val maxAttempts = 10
6869
for (_ <- 1 to maxAttempts) {
6970
if (checkReadinessProbe) {

0 commit comments

Comments
 (0)