@@ -2,12 +2,12 @@ package io.weaviate.spark
22
33import com .google .gson .reflect .TypeToken
44import com .google .gson .{Gson , JsonSyntaxException }
5+ import io .weaviate .client6 .v1 .api .collections .data .Reference
56import org .apache .spark .internal .Logging
67import org .apache .spark .sql .catalyst .InternalRow
78import org .apache .spark .sql .connector .write .{DataWriter , WriterCommitMessage }
89import org .apache .spark .sql .types ._
9- import io .weaviate .client .v1 .data .model .WeaviateObject
10- import io .weaviate .client .v1 .schema .model .WeaviateClass
10+ import io .weaviate .client6 .v1 .api .collections .{CollectionConfig , ObjectMetadata , Vectors , WeaviateObject }
1111import org .apache .spark .sql .catalyst .util .{ArrayData , GenericArrayData }
1212
1313import java .util .{Map => JavaMap }
@@ -20,12 +20,12 @@ case class WeaviateCommitMessage(msg: String) extends WriterCommitMessage
2020
2121case class WeaviateDataWriter (weaviateOptions : WeaviateOptions , schema : StructType )
2222 extends DataWriter [InternalRow ] with Serializable with Logging {
23- var batch = mutable.Map [String , WeaviateObject ]()
24- private val weaviateClass = weaviateOptions.getWeaviateClass ()
23+ var batch = mutable.Map [String , WeaviateObject [ JavaMap [ String , Object ], Reference , ObjectMetadata ] ]()
24+ private lazy val weaviateClass = weaviateOptions.getCollectionConfig ()
2525
/**
 * Buffers a single row for the next batch upload.
 *
 * The row is converted with `buildWeaviateObject` and stored in `batch`
 * under the object's UUID. As soon as the number of buffered objects reaches
 * `weaviateOptions.batchSize`, `writeBatch()` is invoked.
 *
 * @param record the row handed to this writer
 */
override def write(record: InternalRow): Unit = {
  val converted = buildWeaviateObject(record, weaviateClass)
  // Keyed by UUID, so a later object with the same UUID replaces the earlier entry.
  val entry = converted.uuid() -> converted
  batch += entry

  val batchIsFull = batch.size >= weaviateOptions.batchSize
  if (batchIsFull) {
    writeBatch()
  }
}
@@ -36,38 +36,35 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
3636 val consistencyLevel = weaviateOptions.consistencyLevel
3737 val client = weaviateOptions.getClient()
3838
39- val results = if (consistencyLevel != " " ) {
40- logInfo(s " Writing using consistency level: ${consistencyLevel}" )
41- client.batch().objectsBatcher().withObjects(batch.values.toList: _* ).withConsistencyLevel(consistencyLevel).run()
42- } else {
43- client.batch().objectsBatcher().withObjects(batch.values.toList: _* ).run()
44- }
39+ val collection = client.collections.use(weaviateOptions.className).withTenant(weaviateOptions.tenant).withConsistencyLevel(consistencyLevel)
40+
41+ val results = collection.data.insertMany(batch.values.toList.asJava)
4542
4643 val IDs = batch.keys.toList
4744
48- if (results.hasErrors || results.getResult == null ) {
45+ if (results.errors() != null && ! results.errors().isEmpty ) {
4946 if (retries == 0 ) {
5047 throw WeaviateResultError (s " error getting result and no more retries left. " +
51- s " Error from Weaviate: ${results.getError.getMessages }" )
48+ s " Error from Weaviate: ${results.errors().asScala.mkString( " , " ) }" )
5249 }
5350 if (retries > 0 ) {
54- logError(s " batch error: ${results.getError.getMessages }, will retry " )
51+ logError(s " batch error: ${results.errors().asScala.mkString( " , " ) }, will retry " )
5552 logInfo(s " Retrying batch in ${weaviateOptions.retriesBackoff} seconds. Batch has following IDs: ${IDs }" )
5653 Thread .sleep(weaviateOptions.retriesBackoff * 1000 )
5754 writeBatch(retries - 1 )
5855 }
5956 } else {
60- val (objectsWithSuccess, objectsWithError) = results.getResult. partition(_.getResult.getErrors == null )
57+ val (objectsWithSuccess, objectsWithError) = results.responses().asScala. partition(_.error() == null )
6158 if (objectsWithError.size > 0 && retries > 0 ) {
62- val errors = objectsWithError.map(obj => s " ${obj.getId }: ${obj.getResult.getErrors.toString }" )
63- val successIDs = objectsWithSuccess.map(_.getId ).toList
59+ val errors = objectsWithError.map(obj => s " ${obj.uuid() }: ${obj.error() }" )
60+ val successIDs = objectsWithSuccess.map(_.uuid() ).toList
6461 logWarning(s " Successfully imported ${successIDs}. " +
6562 s " Retrying objects with an error. Following objects in the batch upload had an error: ${errors.mkString(" Array(" , " , " , " )" )}" )
6663 batch = batch -- successIDs
6764 writeBatch(retries - 1 )
6865 } else if (objectsWithError.size > 0 ) {
69- val errorIds = objectsWithError.map(obj => obj.getId )
70- val errorMessages = objectsWithError.map(obj => obj.getResult.getErrors.toString ).distinct
66+ val errorIds = objectsWithError.map(obj => obj.uuid() )
67+ val errorMessages = objectsWithError.map(obj => obj.error() ).distinct
7168 throw WeaviateResultError (s " Error writing to weaviate and no more retries left. " +
7269 s " IDs with errors: ${errorIds.mkString(" Array(" , " , " , " )" )}. " +
7370 s " Error messages: ${errorMessages.mkString(" Array(" , " , " , " )" )}" )
@@ -79,17 +76,19 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
7976 }
8077 }
8178
82- private [spark] def buildWeaviateObject (record : InternalRow , weaviateClass : WeaviateClass = null ): WeaviateObject = {
83- var builder = WeaviateObject .builder.className(weaviateOptions.className)
84- if (weaviateOptions.tenant != null ) {
85- builder = builder.tenant(weaviateOptions.tenant)
86- }
79+ private [spark] def buildWeaviateObject (record : InternalRow , collectionConfig : CollectionConfig = null ): WeaviateObject [java.util.Map [String , Object ], Reference , ObjectMetadata ] = {
80+ var builder : WeaviateObject .Builder [java.util.Map [String , Object ], Reference , ObjectMetadata ] = new WeaviateObject .Builder ()
81+
82+ builder = builder.collection(weaviateOptions.className)
83+
84+ var id : String = null
8785 val properties = mutable.Map [String , AnyRef ]()
86+ var vector : Array [Float ] = null
8887 val vectors = mutable.Map [String , Array [Float ]]()
8988 val multiVectors = mutable.Map [String , Array [Array [Float ]]]()
9089 schema.zipWithIndex.foreach(field =>
9190 field._1.name match {
92- case weaviateOptions.vector => builder = builder.vector( record.getArray(field._2).toArray(FloatType ) )
91+ case weaviateOptions.vector => vector = record.getArray(field._2).toArray(FloatType )
9392 case key if weaviateOptions.vectors.contains(key) => vectors += (weaviateOptions.vectors(key) -> record.getArray(field._2).toArray(FloatType ))
9493 case key if weaviateOptions.multiVectors.contains(key) => {
9594 val multiVectorArrayData = record.get(field._2, ArrayType (ArrayType (FloatType ))) match {
@@ -105,31 +104,43 @@ case class WeaviateDataWriter(weaviateOptions: WeaviateOptions, schema: StructTy
105104
106105 multiVectors += (weaviateOptions.multiVectors(key) -> multiVector)
107106 }
108- case weaviateOptions.id => builder = builder.id(record.getString(field._2))
109- case _ => properties(field._1.name) = getPropertyValue(field._2, record, field._1.dataType, false , field._1.name, weaviateClass )
107+ case weaviateOptions.id => id = record.getString(field._2) // builder.id(record.getString(field._2))
108+ case _ => properties(field._1.name) = getPropertyValue(field._2, record, field._1.dataType, false , field._1.name, collectionConfig )
110109 }
111110 )
111+ val metadata = new ObjectMetadata .Builder ()
112+
112113 if (weaviateOptions.id == null ) {
113- builder.id(java.util.UUID .randomUUID.toString)
114+ metadata.uuid(java.util.UUID .randomUUID.toString)
115+ } else {
116+ metadata.uuid(id)
114117 }
115118
119+ val allvectors = ListBuffer .empty[Vectors ]
120+ if (vector != null ) {
121+ allvectors += Vectors .of(vector)
122+ }
116123 if (vectors.nonEmpty) {
117- builder.vectors(vectors.map { case (key, arr) => key -> arr.map(Float .box) }.asJava)
124+ val arr = vectors.map { case (key, arr) => Vectors .of(key, arr) }.toArray
125+ allvectors ++= arr
118126 }
119127 if (multiVectors.nonEmpty) {
120- builder.multiVectors(multiVectors.map { case (key, multiVector) => key -> multiVector.map { vec => { vec.map(Float .box) }} }.toMap.asJava)
128+ val arr = multiVectors.map { case (key, multiVector) => Vectors .of(key, multiVector) }.toArray
129+ allvectors ++= arr
121130 }
122- builder.properties(properties.asJava).build
131+ metadata.vectors(allvectors.toSeq : _* )
132+
133+ builder.properties(properties.asJava).metadata(metadata.build()).build()
123134 }
124135
125- def getPropertyValue (index : Int , record : InternalRow , dataType : DataType , parseObjectArrayItem : Boolean , propertyName : String , weaviateClass : WeaviateClass ): AnyRef = {
136+ def getPropertyValue (index : Int , record : InternalRow , dataType : DataType , parseObjectArrayItem : Boolean , propertyName : String , collectionConfig : CollectionConfig ): AnyRef = {
126137 val valueFromField = getValueFromField(index, record, dataType, parseObjectArrayItem)
127- if (weaviateClass != null ) {
138+ if (collectionConfig != null ) {
128139 var dt = " "
129- weaviateClass.getProperties .forEach(p => {
130- if (p.getName == propertyName) {
140+ collectionConfig.properties() .forEach(p => {
141+ if (p.propertyName() == propertyName) {
131142 // we are just looking for geoCoordinates or phoneNumber type
132- dt = p.getDataType .get(0 )
143+ dt = p.dataTypes() .get(0 )
133144 }
134145 })
135146 if ((dt == " geoCoordinates" || dt == " phoneNumber" ) && valueFromField.isInstanceOf [String ]) {
0 commit comments