@@ -2,12 +2,12 @@ package io.weaviate.spark
 
 import com.google.gson.reflect.TypeToken
 import com.google.gson.{Gson, JsonSyntaxException}
-import io.weaviate.client6.v1.api.collections.data.Reference
+import io.weaviate.client6.v1.api.collections.WeaviateObject
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
 import org.apache.spark.sql.types._
-import io.weaviate.client6.v1.api.collections.{CollectionConfig, ObjectMetadata, Vectors, WeaviateObject}
+import io.weaviate.client6.v1.api.collections.{CollectionConfig, Vectors}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 
 import java.util.{Map => JavaMap}
@@ -20,23 +20,25 @@ case class WeaviateCommitMessage(msg: String) extends WriterCommitMessage
 
 case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructType)
   extends DataWriter[InternalRow] with Serializable with Logging {
-  var batch = mutable.Map[String, WeaviateObject[JavaMap[String, Object], Reference, ObjectMetadata]]()
+  var batch = mutable.Map[String, WeaviateObject[JavaMap[String, Object]]]()
   private lazy val weaviateClass = weaviateOptions.getCollectionConfig()
 
   override def write(record: InternalRow): Unit = {
     val weaviateObject = buildWeaviateObject(record, weaviateClass)
-    batch += (weaviateObject.uuid() -> weaviateObject)
+    batch += (weaviateObject.uuid -> weaviateObject)
 
     if (batch.size >= weaviateOptions.batchSize) writeBatch()
   }
 
   def writeBatch(retries: Int = weaviateOptions.retries): Unit = {
     if (batch.isEmpty) return
 
-    val consistencyLevel = weaviateOptions.consistencyLevel
     val client = weaviateOptions.getClient()
 
-    val collection = client.collections.use(weaviateOptions.className).withTenant(weaviateOptions.tenant).withConsistencyLevel(consistencyLevel)
+    val collection = client.collections
+      .use(weaviateOptions.className)
+      .withTenant(weaviateOptions.tenant)
+      .withConsistencyLevel(weaviateOptions.consistencyLevel)
 
     val results = collection.data.insertMany(batch.values.toList.asJava)
 
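For illustration only (not part of this commit): a minimal sketch of the client6 fluent chain that writeBatch now uses, based solely on the calls visible in the hunk above; the helper name and its parameter list are assumptions.

// Hedged sketch, not from the commit: resolve a collection handle once, then push a
// prepared batch with a single insertMany call. Only API calls shown in the diff are used.
private def flushBatch(
    options: WeaviateOptions,
    objects: java.util.List[WeaviateObject[java.util.Map[String, Object]]]): Unit = {
  val collection = options.getClient().collections
    .use(options.className)
    .withTenant(options.tenant)
    .withConsistencyLevel(options.consistencyLevel)
  collection.data.insertMany(objects) // per-object results are ignored in this sketch
}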
@@ -76,11 +78,9 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
     }
   }
 
-  private[spark] def buildWeaviateObject(record: InternalRow, collectionConfig: CollectionConfig = null): WeaviateObject[java.util.Map[String, Object], Reference, ObjectMetadata] = {
-    var builder: WeaviateObject.Builder[java.util.Map[String, Object], Reference, ObjectMetadata] = new WeaviateObject.Builder()
-    builder = builder.collection(weaviateOptions.className)
+  private[spark] def buildWeaviateObject(record: InternalRow, collectionConfig: CollectionConfig = null): WeaviateObject[java.util.Map[String, Object]] = {
+    val builder: WeaviateObject.Builder[java.util.Map[String, Object]] = new WeaviateObject.Builder()
 
-    val metadata = new ObjectMetadata.Builder()
     val properties = mutable.Map[String, AnyRef]()
     var vector: Array[Float] = null
     val vectors = mutable.Map[String, Array[Float]]()
@@ -103,13 +103,13 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
 
           multiVectors += (weaviateOptions.multiVectors(key) -> multiVector)
         }
-        case weaviateOptions.id => metadata.uuid(record.getString(field._2))
+        case weaviateOptions.id => builder.uuid(record.getString(field._2))
         case _ => properties(field._1.name) = getPropertyValue(field._2, record, field._1.dataType, false, field._1.name, collectionConfig)
       }
     )
 
     if (weaviateOptions.id == null) {
-      metadata.uuid(java.util.UUID.randomUUID.toString)
+      builder.uuid(java.util.UUID.randomUUID.toString)
     }
 
     val allVectors = ListBuffer.empty[Vectors]
@@ -124,9 +124,8 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
       val arr = multiVectors.map { case (key, multiVector) => Vectors.of(key, multiVector) }.toArray
       allVectors ++= arr
     }
-    metadata.vectors(allVectors.toSeq: _*)
 
-    builder.properties(properties.asJava).metadata(metadata.build()).build()
+    builder.tenant(weaviateOptions.tenant).properties(properties.asJava).vectors(allVectors.toSeq: _*).build()
   }
 
   def getPropertyValue(index: Int, record: InternalRow, dataType: DataType, parseObjectArrayItem: Boolean, propertyName: String, collectionConfig: CollectionConfig): AnyRef = {
@@ -216,7 +215,7 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
         })
       }
       objList.asJava
-      case default => throw new SparkDataTypeNotSupported(s"DataType ${default} is not supported by Weaviate")
+      case default => throw SparkDataTypeNotSupported(s"DataType ${default} is not supported by Weaviate")
     }
   }
 
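As a hedged illustration of the builder migration in this commit (not part of the diff itself): uuid, tenant, properties and vectors now sit directly on WeaviateObject.Builder instead of a separate ObjectMetadata builder. The property values, tenant name, vector name, and the single-vector Vectors.of overload below are placeholders/assumptions; the builder methods are the ones exercised in buildWeaviateObject above.

import scala.jdk.CollectionConverters._   // assumes Scala 2.13-style converters

// Hedged sketch, not part of the commit: build one object the way the reworked
// buildWeaviateObject does, using placeholder data.
val props: java.util.Map[String, Object] =
  Map[String, AnyRef]("title" -> "example").asJava             // placeholder property
val obj: WeaviateObject[java.util.Map[String, Object]] =
  new WeaviateObject.Builder[java.util.Map[String, Object]]()
    .uuid(java.util.UUID.randomUUID.toString)
    .tenant("tenant-a")                                         // placeholder tenant
    .properties(props)
    .vectors(Vectors.of("default", Array(0.1f, 0.2f, 0.3f)))    // assumed named-vector overload
    .build()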